/*
 * Decompiled with CFR 0.152.
 */
package nl.sidnlabs.pcap.decoder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import nl.sidnlabs.pcap.PcapReaderUtil;
import nl.sidnlabs.pcap.SequencePayload;
import nl.sidnlabs.pcap.decoder.DNSDecoder;
import nl.sidnlabs.pcap.decoder.PacketReader;
import nl.sidnlabs.pcap.packet.DNSPacket;
import nl.sidnlabs.pcap.packet.FlowData;
import nl.sidnlabs.pcap.packet.Packet;
import nl.sidnlabs.pcap.packet.TCPFlow;
import nl.sidnlabs.pcap.packet.TcpHandshake;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class TCPDecoder
implements PacketReader {
    private static final Logger log = LogManager.getLogger(TCPDecoder.class);
    private static final int PROTOCOL_HEADER_TCP_SEQ_OFFSET = 4;
    private static final int PROTOCOL_HEADER_TCP_ACK_OFFSET = 8;
    private static final int TCP_HEADER_DATA_OFFSET = 12;
    private static final int PROTOCOL_HEADER_WINDOW_SIZE_OFFSET = 14;
    private static final int TCP_DNS_LENGTH_PREFIX = 2;
    private DNSDecoder dnsDecoder;
    private Map<TCPFlow, FlowData> flows = new HashMap<TCPFlow, FlowData>();
    private Map<TCPFlow, TcpHandshake> handshakes = new HashMap<TCPFlow, TcpHandshake>();
    private Map<TCPFlow, Packet> reassembledPackets = new HashMap<TCPFlow, Packet>();
    private int reqPacketCounter = 0;
    private int rspPacketCounter = 0;

    public TCPDecoder() {
        this(false);
    }

    public TCPDecoder(boolean allowfail) {
        this.dnsDecoder = new DNSDecoder(allowfail);
    }

    @Override
    public Packet reassemble(Packet packet, byte[] packetData) {
        Packet resassembledPacket;
        Packet reassembledPacket;
        boolean hasPayload;
        byte[] packetPayload = this.decode(packet, packetData);
        if (!this.isDNS(packet)) {
            return Packet.NULL;
        }
        if (packet.isTcpFlagRst()) {
            if (log.isDebugEnabled()) {
                log.debug("Connection RESET for src: {} dst: {}", (Object)packet.getSrc(), (Object)packet.getDst());
            }
            TCPFlow flow = packet.getFlow();
            this.handshakes.remove(flow);
            this.flows.remove(flow);
            this.reassembledPackets.remove(flow);
            return Packet.NULL;
        }
        TCPFlow flow = packet.getFlow();
        boolean isServer = packet.getSrcPort() == 53;
        boolean bl = hasPayload = packetPayload.length > 0;
        if (this.handshake(packet, isServer)) {
            return Packet.NULL;
        }
        if (isServer && packet.isTcpFlagAck() && hasPayload && (reassembledPacket = this.reassembledPackets.get(packet.getReverseFlow())) != null && reassembledPacket.getTcpSeq() == packet.getTcpSeq()) {
            reassembledPacket.setTcpRetransmission(true);
            if (log.isDebugEnabled()) {
                log.debug("Ignoring duplicate packet for src: {} dst: {}", (Object)packet.getSrc(), (Object)packet.getDst());
            }
            return Packet.NULL;
        }
        if (!isServer && packet.isTcpFlagAck() && (resassembledPacket = this.reassembledPackets.remove(packet.getFlow())) != null) {
            if (!resassembledPacket.isTcpRetransmission() && resassembledPacket.nextAck() == packet.getTcpAck()) {
                resassembledPacket.setTcpPacketRtt((int)(packet.getTsMilli() - resassembledPacket.getTsMilli()));
            }
            return resassembledPacket;
        }
        FlowData fd = this.flows.get(flow);
        if (fd == null) {
            if (packetPayload.length < 2) {
                return Packet.NULL;
            }
            fd = new FlowData();
            this.flows.put(flow, fd);
        }
        if (hasPayload) {
            SequencePayload sequencePayload = new SequencePayload(packet.getTcpSeq(), packetPayload, System.currentTimeMillis(), flow);
            if (!fd.addPayload(sequencePayload)) {
                return Packet.NULL;
            }
            if (fd.size() == 1) {
                fd.setNextDnsMsgLen(this.dnsMessageLen(packetPayload, 0));
                fd.setBytesAvail(packetPayload.length);
            } else {
                fd.setBytesAvail(fd.getBytesAvail() + packetPayload.length);
            }
        }
        if (packet.isTcpFlagFin() && !this.isNextPayloadAvail(fd)) {
            this.flows.remove(flow);
            return Packet.NULL;
        }
        if (packet.isTcpFlagFin() || packet.isTcpFlagPsh() || this.isNextPayloadAvail(fd)) {
            if (!fd.isNextPayloadAvail()) {
                return Packet.NULL;
            }
            this.flows.remove(flow);
            if (fd != null && fd.size() > 0) {
                byte[] remainder;
                packet.setReassembledTCPFragments(fd.size());
                packetPayload = new byte[fd.getBytesAvail()];
                List<SequencePayload> linkSequencePayloads = this.linkSequencePayloads(fd.getSortedPayloads(), packet, 0);
                fd.setBytesAvail(linkSequencePayloads.stream().mapToInt(s -> s.getBytes().length).sum());
                if (linkSequencePayloads.isEmpty() || !fd.isNextPayloadAvail()) {
                    return Packet.NULL;
                }
                int destPos = 0;
                SequencePayload prev = null;
                for (SequencePayload seqPayload : linkSequencePayloads) {
                    System.arraycopy(seqPayload.getBytes(), 0, packetPayload, destPos, seqPayload.getBytes().length);
                    destPos += seqPayload.getBytes().length;
                    prev = seqPayload;
                }
                TcpHandshake handshake = this.handshakes.remove(flow);
                if (handshake != null && TcpHandshake.HANDSHAKE_STATE.ACK_RECV == handshake.getState()) {
                    packet.setTcpHandshake(handshake);
                }
                if (isServer) {
                    remainder = this.decodeDnsPayload(packet, packetPayload);
                    if (remainder.length > 0 && prev != null) {
                        this.createNewFlowWithRemainer(flow, remainder, prev);
                    }
                    if (packet != Packet.NULL) {
                        ++this.rspPacketCounter;
                        this.reassembledPackets.put(packet.getReverseFlow(), packet);
                    }
                } else {
                    ++this.reqPacketCounter;
                    remainder = this.decodeDnsPayload(packet, packetPayload);
                    if (remainder.length > 0 && prev != null) {
                        this.createNewFlowWithRemainer(flow, remainder, prev);
                    }
                    return packet;
                }
            }
        }
        return Packet.NULL;
    }

    public byte[] decode(Packet packet, byte[] packetData) {
        packet.setSrcPort(PcapReaderUtil.convertShort(packetData, 0));
        packet.setDstPort(PcapReaderUtil.convertShort(packetData, 2));
        int tcpOrUdpHeaderSize = this.getTcpHeaderLength(packetData);
        if (tcpOrUdpHeaderSize == -1) {
            return new byte[0];
        }
        packet.setTcpHeaderLen(tcpOrUdpHeaderSize);
        packet.setTcpSeq(PcapReaderUtil.convertUnsignedInt(packetData, 4));
        packet.setTcpAck(PcapReaderUtil.convertUnsignedInt(packetData, 8));
        int flags = PcapReaderUtil.convertShort(new byte[]{packetData[12], packetData[13]}) & 0x1FF;
        packet.setTcpFlagNs((flags & 0x100) != 0);
        packet.setTcpFlagCwr((flags & 0x80) != 0);
        packet.setTcpFlagEce((flags & 0x40) != 0);
        packet.setTcpFlagUrg((flags & 0x20) != 0);
        packet.setTcpFlagAck((flags & 0x10) != 0);
        packet.setTcpFlagPsh((flags & 8) != 0);
        packet.setTcpFlagRst((flags & 4) != 0);
        packet.setTcpFlagSyn((flags & 2) != 0);
        packet.setTcpFlagFin((flags & 1) != 0);
        packet.setTcpWindowSize(PcapReaderUtil.convertShort(packetData, 14));
        int payloadLength = packetData.length - tcpOrUdpHeaderSize;
        byte[] data = PcapReaderUtil.readPayload(packetData, tcpOrUdpHeaderSize, payloadLength);
        packet.setPayloadLength(payloadLength);
        packet.setLen(packetData.length);
        return data;
    }

    private List<SequencePayload> linkSequencePayloads(List<SequencePayload> seqPayloads, Packet packet, int trycount) {
        if (trycount > 10) {
            return Collections.emptyList();
        }
        SequencePayload prev = null;
        boolean error = false;
        for (SequencePayload seqPayload : seqPayloads) {
            if (prev != null && !seqPayload.linked(prev)) {
                log.warn("Packet src: " + packet.getSrc() + " dst: " + packet.getDst() + " has Broken sequence chain between " + seqPayload + " and " + prev);
                seqPayload.setIgnore(true);
                error = true;
            }
            prev = seqPayload;
        }
        if (error) {
            List<SequencePayload> newList = seqPayloads.stream().filter(p -> !p.isIgnore()).sorted().collect(Collectors.toList());
            return this.linkSequencePayloads(newList, packet, ++trycount);
        }
        return seqPayloads;
    }

    private boolean isNextPayloadAvail(FlowData fd) {
        return fd != null && fd.isNextPayloadAvail();
    }

    private void createNewFlowWithRemainer(TCPFlow flow, byte[] remainder, SequencePayload lastPayload) {
        lastPayload.setBytes(remainder);
        FlowData fd = new FlowData();
        fd.setBytesAvail(remainder.length);
        fd.setNextDnsMsgLen(this.dnsMessageLen(remainder, 0));
        fd.addPayload(lastPayload);
        this.flows.put(flow, fd);
    }

    private boolean handshake(Packet packet, boolean server) {
        TcpHandshake handshake;
        if (!server && packet.isTcpFlagSyn() && !packet.isTcpFlagAck()) {
            if (this.handshakes.containsKey(packet.getFlow())) {
                this.handshakes.remove(packet.getFlow());
                return true;
            }
            TcpHandshake handshake2 = new TcpHandshake(packet.getTcpSeq());
            handshake2.setSynTs(packet.getTsMilli());
            this.handshakes.put(packet.getFlow(), handshake2);
            return true;
        }
        if (server && packet.isTcpFlagSyn() && packet.isTcpFlagAck()) {
            TCPFlow reverseFlow = packet.getReverseFlow();
            TcpHandshake handshake3 = this.handshakes.get(reverseFlow);
            if (handshake3 != null && handshake3.getClientSynSeq() == packet.getTcpAck() - 1L) {
                if (TcpHandshake.HANDSHAKE_STATE.SYN_RECV == handshake3.getState()) {
                    handshake3.setState(TcpHandshake.HANDSHAKE_STATE.SYN_ACK_SENT);
                    handshake3.setServerAckSeq(packet.getTcpAck());
                    handshake3.setServerSynSeq(packet.getTcpSeq());
                } else {
                    this.handshakes.remove(reverseFlow);
                }
            } else if (log.isDebugEnabled()) {
                log.debug("Cannot find handshake for SYN/ACK, maybe a retry?");
            }
            return true;
        }
        if (!server && packet.isTcpFlagAck() && (handshake = this.handshakes.get(packet.getFlow())) != null && TcpHandshake.HANDSHAKE_STATE.SYN_ACK_SENT == handshake.getState() && packet.getTcpAck() - 1L == handshake.getServerSynSeq()) {
            handshake.setAckTs(packet.getTsMilli());
            handshake.setState(TcpHandshake.HANDSHAKE_STATE.ACK_RECV);
            handshake.setClientAckSeq(packet.getTcpSeq());
            if (!packet.isTcpFlagPsh()) {
                return true;
            }
        }
        return false;
    }

    private int getTcpHeaderLength(byte[] packet) {
        if (12 < packet.length) {
            return (packet[12] >> 4 & 0xF) * 4;
        }
        return -1;
    }

    private int dnsMessageLen(byte[] payload, int payloadIndex) {
        if (payload == null || payload.length < 2) {
            if (log.isDebugEnabled()) {
                log.debug("Reading DNS message len from failed failed, only {} bytes available", (Object)payload.length);
            }
            return 0;
        }
        byte[] lenBytes = new byte[2];
        System.arraycopy(payload, payloadIndex, lenBytes, 0, 2);
        return PcapReaderUtil.convertShort(lenBytes);
    }

    private byte[] decodeDnsPayload(Packet packet, byte[] payload) {
        int msgLen;
        for (int payloadIndex = 0; payload.length > 2 && payloadIndex < payload.length; payloadIndex += msgLen) {
            msgLen = this.dnsMessageLen(payload, payloadIndex);
            if (msgLen > 0 && (payloadIndex += 2) + msgLen <= payload.length) {
                byte[] msgBytes = new byte[msgLen];
                System.arraycopy(payload, payloadIndex, msgBytes, 0, msgLen);
                packet = this.dnsDecoder.decode((DNSPacket)packet, msgBytes);
                continue;
            }
            int index = payloadIndex - 2;
            byte[] remainingBytes = new byte[payload.length - index];
            System.arraycopy(payload, index, remainingBytes, 0, remainingBytes.length);
            return remainingBytes;
        }
        if (log.isDebugEnabled() && ((DNSPacket)packet).getMessageCount() > 1) {
            log.debug("multiple msg in TCP stream");
        }
        return new byte[0];
    }

    public void clearCache(int cacheTTL) {
        ArrayList<TCPFlow> expiredList = new ArrayList<TCPFlow>();
        long now = System.currentTimeMillis();
        block0: for (Map.Entry<TCPFlow, FlowData> entry : this.flows.entrySet()) {
            for (SequencePayload sequencePayload : entry.getValue().getPayloads()) {
                if (sequencePayload.getTime() + (long)cacheTTL > now) continue;
                expiredList.add(entry.getKey());
                continue block0;
            }
        }
        log.info("TCP flow cache size: " + this.flows.size());
        log.info("Expired (to be removed) TCP flows: " + expiredList.size());
        expiredList.stream().forEach(s -> this.flows.remove(s));
    }

    public DNSDecoder getDnsDecoder() {
        return this.dnsDecoder;
    }

    public Map<TCPFlow, FlowData> getFlows() {
        return this.flows;
    }

    public Map<TCPFlow, TcpHandshake> getHandshakes() {
        return this.handshakes;
    }

    public Map<TCPFlow, Packet> getReassembledPackets() {
        return this.reassembledPackets;
    }

    public int getReqPacketCounter() {
        return this.reqPacketCounter;
    }

    public int getRspPacketCounter() {
        return this.rspPacketCounter;
    }

    public void setDnsDecoder(DNSDecoder dnsDecoder) {
        this.dnsDecoder = dnsDecoder;
    }

    public void setFlows(Map<TCPFlow, FlowData> flows) {
        this.flows = flows;
    }

    public void setHandshakes(Map<TCPFlow, TcpHandshake> handshakes) {
        this.handshakes = handshakes;
    }

    public void setReassembledPackets(Map<TCPFlow, Packet> reassembledPackets) {
        this.reassembledPackets = reassembledPackets;
    }

    public void setReqPacketCounter(int reqPacketCounter) {
        this.reqPacketCounter = reqPacketCounter;
    }

    public void setRspPacketCounter(int rspPacketCounter) {
        this.rspPacketCounter = rspPacketCounter;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TCPDecoder)) {
            return false;
        }
        TCPDecoder other = (TCPDecoder)o;
        if (!other.canEqual(this)) {
            return false;
        }
        DNSDecoder this$dnsDecoder = this.getDnsDecoder();
        DNSDecoder other$dnsDecoder = other.getDnsDecoder();
        if (this$dnsDecoder == null ? other$dnsDecoder != null : !((Object)this$dnsDecoder).equals(other$dnsDecoder)) {
            return false;
        }
        Map<TCPFlow, FlowData> this$flows = this.getFlows();
        Map<TCPFlow, FlowData> other$flows = other.getFlows();
        if (this$flows == null ? other$flows != null : !((Object)this$flows).equals(other$flows)) {
            return false;
        }
        Map<TCPFlow, TcpHandshake> this$handshakes = this.getHandshakes();
        Map<TCPFlow, TcpHandshake> other$handshakes = other.getHandshakes();
        if (this$handshakes == null ? other$handshakes != null : !((Object)this$handshakes).equals(other$handshakes)) {
            return false;
        }
        Map<TCPFlow, Packet> this$reassembledPackets = this.getReassembledPackets();
        Map<TCPFlow, Packet> other$reassembledPackets = other.getReassembledPackets();
        if (this$reassembledPackets == null ? other$reassembledPackets != null : !((Object)this$reassembledPackets).equals(other$reassembledPackets)) {
            return false;
        }
        if (this.getReqPacketCounter() != other.getReqPacketCounter()) {
            return false;
        }
        return this.getRspPacketCounter() == other.getRspPacketCounter();
    }

    protected boolean canEqual(Object other) {
        return other instanceof TCPDecoder;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        DNSDecoder $dnsDecoder = this.getDnsDecoder();
        result = result * 59 + ($dnsDecoder == null ? 43 : ((Object)$dnsDecoder).hashCode());
        Map<TCPFlow, FlowData> $flows = this.getFlows();
        result = result * 59 + ($flows == null ? 43 : ((Object)$flows).hashCode());
        Map<TCPFlow, TcpHandshake> $handshakes = this.getHandshakes();
        result = result * 59 + ($handshakes == null ? 43 : ((Object)$handshakes).hashCode());
        Map<TCPFlow, Packet> $reassembledPackets = this.getReassembledPackets();
        result = result * 59 + ($reassembledPackets == null ? 43 : ((Object)$reassembledPackets).hashCode());
        result = result * 59 + this.getReqPacketCounter();
        result = result * 59 + this.getRspPacketCounter();
        return result;
    }

    public String toString() {
        return "TCPDecoder(dnsDecoder=" + this.getDnsDecoder() + ", flows=" + this.getFlows() + ", handshakes=" + this.getHandshakes() + ", reassembledPackets=" + this.getReassembledPackets() + ", reqPacketCounter=" + this.getReqPacketCounter() + ", rspPacketCounter=" + this.getRspPacketCounter() + ")";
    }
}

