/*
 * Decompiled with CFR 0.152.
 */
package nl.sidnlabs.pcap.decoder;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import nl.sidnlabs.pcap.PcapReaderUtil;
import nl.sidnlabs.pcap.decoder.ChainBuffer;
import nl.sidnlabs.pcap.decoder.DNSDecoder;
import nl.sidnlabs.pcap.decoder.Decoder;
import nl.sidnlabs.pcap.packet.DNSPacket;
import nl.sidnlabs.pcap.packet.FlowData;
import nl.sidnlabs.pcap.packet.Packet;
import nl.sidnlabs.pcap.packet.SequencePayload;
import nl.sidnlabs.pcap.packet.TCPFlow;
import nl.sidnlabs.pcap.packet.TcpHandshake;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Decodes TCP segments and reassembles the DNS messages they carry.
 *
 * <p>For every segment the TCP header is parsed into the supplied {@link Packet}, the
 * three-way handshake is tracked per flow (to report the handshake round-trip time) and
 * the segment payload is buffered per {@link TCPFlow}. Once enough payload is available
 * the buffered segments are merged and handed to the {@link DNSDecoder}; each DNS message
 * in the stream is preceded by a {@value #TCP_DNS_LENGTH_PREFIX}-byte length prefix.
 *
 * <p>Instances hold mutable state (flow and handshake maps, counters and two reusable
 * buffers) and this class performs no synchronization of its own.
 */
public class TCPDecoder implements Decoder {
    private static final Logger log = LogManager.getLogger(TCPDecoder.class);

    /** Offset of the source port inside the TCP header. */
    private static final int PROTOCOL_HEADER_SRC_PORT_OFFSET = 0;
    /** Offset of the destination port inside the TCP header. */
    private static final int PROTOCOL_HEADER_DST_PORT_OFFSET = 2;
    /** Offset of the sequence number; also the number of bytes taken by the two ports. */
    private static final int PROTOCOL_HEADER_TCP_SEQ_OFFSET = 4;
    /** Offset of the acknowledgement number. */
    private static final int PROTOCOL_HEADER_TCP_ACK_OFFSET = 8;
    /** Offset of the byte holding the data offset (header length) nibble and the flags. */
    private static final int TCP_HEADER_DATA_OFFSET = 12;
    /** Offset of the window size. */
    private static final int PROTOCOL_HEADER_WINDOW_SIZE_OFFSET = 14;
    /** Offset where TCP options start, i.e. the size of a TCP header without options. */
    private static final int PROTOCOL_HEADER_OPTIONS_OFFSET = 20;
    private static final int PROTOCOL_HEADER_OPTION_LEN_MASK = 31;
    private static final int PROTOCOL_HEADER_OPTION_TIMESTAMP = 8;
    /** Size of the length prefix that precedes every DNS message in a TCP stream. */
    private static final int TCP_DNS_LENGTH_PREFIX = 2;
    /** Source port that marks a segment as being sent by the DNS server. */
    private static final int DNS_PORT = 53;
    /** Initial capacity of the reusable payload and DNS message buffers. */
    private static final int DEFAULT_BUFFER_SIZE = 2048;

    private DNSDecoder dnsDecoder;
    /** Buffered, not yet decoded payload per flow (one entry per direction). */
    private Map<TCPFlow, FlowData> flows = new HashMap<>();
    /** Handshakes in progress, keyed by the flow of the client SYN. */
    private Map<TCPFlow, TcpHandshake> handshakes = new HashMap<>();
    private int packetCounter = 0;
    private int reqPacketCounter = 0;
    private int rspPacketCounter = 0;
    private int dnsRspMsgCounter = 0;
    private int dnsReqMsgCounter = 0;
    /** Payload of the segment currently being processed; handed back to PcapReaderUtil on every decode. */
    private ByteBuffer packetPayload = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE);
    /** Scratch buffer for DNS messages that span more than one underlying chain buffer. */
    private byte[] sharedDnsBuffer = new byte[DEFAULT_BUFFER_SIZE];
    /** Timestamp (ms) of the most recently processed packet; reference point for {@link #clearCache(int)}. */
    private long lastPacketTs = 0L;

    /**
     * Creates a decoder that forwards reassembled DNS payload to the given DNS decoder.
     *
     * @param dnsDecoder decoder used for the DNS messages found in the TCP stream
     */
    public TCPDecoder(DNSDecoder dnsDecoder) {
        this.dnsDecoder = dnsDecoder;
    }

    private FlowData removeFlow(TCPFlow flow) {
        return this.flows.remove(flow);
    }

    private void addFlow(TCPFlow flow, FlowData fd) {
        this.flows.put(flow, fd);
    }

    private FlowData lookupFlow(TCPFlow flow) {
        return this.flows.get(flow);
    }

    /**
     * Processes one TCP segment and returns the packet once it contains decoded DNS data.
     *
     * <p>The TCP header is always decoded into {@code packet}. Segments that are not DNS,
     * that reset the connection, that belong to the handshake, or that are bare ACKs
     * without payload yield {@code Packet.NULL}. Payload is buffered per flow; when a
     * segment with FIN, PSH or ACK arrives and the flow reports that the minimum payload
     * is available, the buffered segments are merged and decoded. Bytes left over after
     * the last complete DNS message are kept in a fresh flow entry for the next segment.
     *
     * @param packet     packet to fill; the DNS decoding step casts it to {@link DNSPacket}
     * @param packetData raw bytes starting at the first byte of the TCP header
     * @return {@code packet} when DNS payload was decoded, otherwise {@code Packet.NULL}
     */
    @Override
    public Packet reassemble(Packet packet, byte[] packetData) {
        ++this.packetCounter;
        if (log.isDebugEnabled()) {
            log.debug("Received {} packets", this.packetCounter);
        }
        this.packetPayload = this.decode(packet, packetData);
        this.lastPacketTs = packet.getTsMilli();
        // isDNS is not declared in this class; presumably supplied by Decoder.
        if (!this.isDNS(packet)) {
            return Packet.NULL;
        }

        TCPFlow flow = packet.getFlow();
        if (packet.isTcpFlagRst()) {
            if (log.isDebugEnabled()) {
                log.debug("Connection RESET for src: {} dst: {}", packet.getSrc(), packet.getDst());
            }
            // NOTE(review): only the state keyed by this packet's own flow is dropped; state
            // stored under the reverse flow stays until it is consumed or expired.
            this.handshakes.remove(flow);
            this.removeFlow(flow);
            return Packet.NULL;
        }

        boolean isServer = packet.getSrcPort() == DNS_PORT;
        boolean bareAck = packet.isTcpFlagAck() && !this.packetPayload.hasRemaining();
        if (this.handshake(packet, isServer) || bareAck) {
            return Packet.NULL;
        }

        FlowData fd = this.lookupFlow(flow);
        if (fd == null) {
            // A new flow must at least contain the DNS length prefix.
            if (this.packetPayload.remaining() < TCP_DNS_LENGTH_PREFIX) {
                return Packet.NULL;
            }
            fd = new FlowData();
            this.addFlow(flow, fd);
        }

        if (this.packetPayload.hasRemaining()) {
            // Size the copy by remaining(), not limit(): get(byte[]) throws
            // BufferUnderflowException when asked for more than remaining() bytes.
            byte[] bytes = new byte[this.packetPayload.remaining()];
            this.packetPayload.get(bytes);
            SequencePayload sequencePayload =
                new SequencePayload(packet.getTcpSeq(), bytes, packet.getTsMilli(), flow);
            if (log.isDebugEnabled()) {
                log.debug("reassemble, tcp bytes len: {}", bytes.length);
            }
            fd.addPayload(sequencePayload);
        }

        if (!fd.isMinPayloadAvail()) {
            // Not enough data to decode. A FIN means no more data will follow on this
            // flow, so the incomplete flow is dropped; otherwise wait for more segments.
            if (packet.isTcpFlagFin()) {
                this.removeFlow(flow);
            }
            return Packet.NULL;
        }
        if (!(packet.isTcpFlagFin() || packet.isTcpFlagPsh() || packet.isTcpFlagAck())) {
            return Packet.NULL;
        }

        this.removeFlow(flow);
        if (fd.size() <= 0) {
            return Packet.NULL;
        }
        packet.setReassembledTCPFragments(fd.size());
        List<SequencePayload> sequencePayloads = this.getFlowPayloadsAsList(fd, packet);
        // isMinPayloadAvail() is re-checked because getFlowPayloadsAsList may go through
        // FlowData.getSortedPayloads(), whose side effects are not visible from here.
        if (sequencePayloads.isEmpty() || !fd.isMinPayloadAvail()) {
            return Packet.NULL;
        }

        ChainBuffer combinedBuffer = this.mergePayloads(sequencePayloads);
        SequencePayload lastSequencePayload = sequencePayloads.get(sequencePayloads.size() - 1);
        TcpHandshake handshake = this.handshakes.remove(flow);
        if (handshake != null && TcpHandshake.HANDSHAKE_STATE.ACK_RECV == handshake.getState()) {
            packet.setTcpHandshakeRTT(handshake.rtt());
        }

        if (!isServer) {
            // Requests are counted before decoding, responses only after a successful decode.
            ++this.reqPacketCounter;
        }
        ChainBuffer remainder = this.decodeDnsPayload(packet, combinedBuffer);
        int messageCount = ((DNSPacket) packet).getMessageCount();
        if (isServer) {
            this.dnsRspMsgCounter += messageCount;
        } else {
            this.dnsReqMsgCounter += messageCount;
        }
        if (remainder.readableBytes() > 0) {
            // Keep the bytes of an incomplete trailing message for the next segment.
            this.createNewFlowWithRemainder(flow, remainder, lastSequencePayload);
        }

        if (!isServer) {
            return packet;
        }
        if (packet != Packet.NULL) {
            ++this.rspPacketCounter;
            if (log.isDebugEnabled()) {
                log.debug("Reassembled packet with {} DNS messages", ((DNSPacket) packet).getMessageCount());
            }
            return packet;
        }
        return Packet.NULL;
    }

    /**
     * Returns the payloads of a flow in the order in which they must be merged.
     *
     * @return the single payload as-is, or the sorted payloads when they form an unbroken
     *         chain; an empty list when the chain is broken
     */
    private List<SequencePayload> getFlowPayloadsAsList(FlowData fd, Packet packet) {
        if (fd.size() == 1) {
            return fd.getPayloads();
        }
        return this.linkSequencePayloads(fd.getSortedPayloads(), packet);
    }

    /**
     * Chains the payload bytes of all given segments into one buffer.
     *
     * <p>A payload that already carries a buffer (the remainder stored by
     * {@link #createNewFlowWithRemainder}) becomes the base buffer; the bytes of every
     * other payload are appended to it.
     *
     * @return the merged buffer, or {@code null} when {@code sequencePayloads} is empty
     */
    private ChainBuffer mergePayloads(List<SequencePayload> sequencePayloads) {
        ChainBuffer merged = null;
        for (SequencePayload sp : sequencePayloads) {
            if (sp.hasBuffer()) {
                merged = sp.getBuffer();
            } else {
                if (merged == null) {
                    merged = new ChainBuffer();
                }
                merged.addLast(sp.getBytes());
            }
        }
        return merged;
    }

    /**
     * Parses the TCP header in {@code packetData} into {@code packet} and extracts the payload.
     *
     * <p>Ports are set as soon as the data is long enough to contain them. The remaining
     * header fields are only read when the data holds at least a complete option-less
     * header ({@value #PROTOCOL_HEADER_OPTIONS_OFFSET} bytes). If the data is shorter than
     * that, or the header length announced by the data offset field is smaller than the
     * minimal header or larger than the available data, the segment is treated as
     * truncated/malformed and an empty buffer is returned instead of reading out of bounds.
     *
     * @param packet     packet receiving the decoded header fields
     * @param packetData raw bytes starting at the first byte of the TCP header
     * @return buffer with the TCP payload; an empty buffer when there is no usable payload
     */
    public ByteBuffer decode(Packet packet, byte[] packetData) {
        if (packetData.length < PROTOCOL_HEADER_TCP_SEQ_OFFSET) {
            // Not even the two port numbers are present.
            return ByteBuffer.allocate(0);
        }
        packet.setSrcPort(PcapReaderUtil.convertShort(packetData, PROTOCOL_HEADER_SRC_PORT_OFFSET));
        packet.setDstPort(PcapReaderUtil.convertShort(packetData, PROTOCOL_HEADER_DST_PORT_OFFSET));
        if (packetData.length < PROTOCOL_HEADER_OPTIONS_OFFSET) {
            // Truncated fixed header: the fields below cannot be read safely.
            return ByteBuffer.allocate(0);
        }

        int tcpHeaderLen = this.getTcpHeaderLength(packetData);
        packet.setTcpHeaderLen(tcpHeaderLen);
        packet.setTcpSeq(PcapReaderUtil.convertUnsignedInt(packetData, PROTOCOL_HEADER_TCP_SEQ_OFFSET));
        packet.setTcpAck(PcapReaderUtil.convertUnsignedInt(packetData, PROTOCOL_HEADER_TCP_ACK_OFFSET));

        // The 9 flag bits are the low bits of the 16-bit word at the data offset position.
        int flags = PcapReaderUtil.convertShort(
            new byte[]{packetData[TCP_HEADER_DATA_OFFSET], packetData[TCP_HEADER_DATA_OFFSET + 1]}) & 0x1FF;
        packet.setTcpFlagNs((flags & 0x100) != 0);
        packet.setTcpFlagCwr((flags & 0x80) != 0);
        packet.setTcpFlagEce((flags & 0x40) != 0);
        packet.setTcpFlagUrg((flags & 0x20) != 0);
        packet.setTcpFlagAck((flags & 0x10) != 0);
        packet.setTcpFlagPsh((flags & 8) != 0);
        packet.setTcpFlagRst((flags & 4) != 0);
        packet.setTcpFlagSyn((flags & 2) != 0);
        packet.setTcpFlagFin((flags & 1) != 0);
        packet.setTcpWindowSize(PcapReaderUtil.convertShort(packetData, PROTOCOL_HEADER_WINDOW_SIZE_OFFSET));
        packet.setLen(packetData.length);

        int payloadLength = packetData.length - tcpHeaderLen;
        if (tcpHeaderLen < PROTOCOL_HEADER_OPTIONS_OFFSET || payloadLength < 0) {
            // Invalid data offset: header shorter than the minimum or extending past the data.
            packet.setPayloadLength(0);
            return ByteBuffer.allocate(0);
        }
        packet.setPayloadLength(payloadLength);
        return PcapReaderUtil.readPayloadToBuffer(packetData, tcpHeaderLen, payloadLength, this.packetPayload);
    }

    /**
     * Reads the first four bytes of {@code buf} as an unsigned big-endian 32-bit integer.
     *
     * @param buf array holding at least four bytes
     * @return value in the range {@code 0 .. 0xFFFFFFFF}
     */
    public long readUnsignedInt(byte[] buf) {
        long b0 = buf[0] & 0xFF;
        long b1 = buf[1] & 0xFF;
        long b2 = buf[2] & 0xFF;
        long b3 = buf[3] & 0xFF;
        return b0 << 24 | b1 << 16 | b2 << 8 | b3;
    }

    /**
     * Verifies that every payload is linked to its predecessor.
     *
     * @return {@code seqPayloads} when the chain is unbroken, otherwise an empty list
     */
    private List<SequencePayload> linkSequencePayloads(List<SequencePayload> seqPayloads, Packet packet) {
        SequencePayload prev = null;
        for (SequencePayload current : seqPayloads) {
            if (prev != null && !current.linked(prev)) {
                if (log.isDebugEnabled()) {
                    log.debug("Packet src: {} dst: {} has Broken sequence chain between {} and {}",
                        packet.getSrc(), packet.getDst(), current, prev);
                }
                return Collections.emptyList();
            }
            prev = current;
        }
        return seqPayloads;
    }

    /**
     * Stores the undecoded remainder of a stream under {@code flow} so that the next
     * segment of that flow continues where decoding stopped. The remainder is attached to
     * {@code lastPayload}, which becomes the only payload of the new flow entry.
     */
    private void createNewFlowWithRemainder(TCPFlow flow, ChainBuffer remainder, SequencePayload lastPayload) {
        // NOTE(review): presumably discards the already consumed part of the buffer; confirm in ChainBuffer.
        remainder.clean();
        lastPayload.setBuffer(remainder);
        FlowData fd = new FlowData();
        fd.addPayload(lastPayload);
        this.addFlow(flow, fd);
    }

    /**
     * Tracks the TCP three-way handshake for the flow of {@code packet}.
     *
     * @param packet segment with decoded TCP header
     * @param server {@code true} when the segment was sent by the server
     * @return {@code true} when the segment was consumed as part of the handshake and must
     *         not be processed further; the final client ACK is not consumed when it also
     *         has PSH set
     */
    private boolean handshake(Packet packet, boolean server) {
        boolean syn = packet.isTcpFlagSyn();
        boolean ack = packet.isTcpFlagAck();

        if (!server && syn && !ack) {
            // Client SYN. A second SYN for a flow with a pending handshake discards that handshake.
            if (this.handshakes.containsKey(packet.getFlow())) {
                this.handshakes.remove(packet.getFlow());
                return true;
            }
            TcpHandshake started = new TcpHandshake(packet.getTcpSeq());
            started.setSynTs(packet.getTsMilli());
            this.handshakes.put(packet.getFlow(), started);
            return true;
        }

        if (server && syn && ack) {
            // Server SYN/ACK: the handshake is stored under the client's flow.
            TCPFlow reverseFlow = packet.getReverseFlow();
            TcpHandshake pending = this.handshakes.get(reverseFlow);
            if (pending != null && pending.getClientSynSeq() == packet.getTcpAck() - 1L) {
                if (TcpHandshake.HANDSHAKE_STATE.SYN_RECV == pending.getState()) {
                    pending.setState(TcpHandshake.HANDSHAKE_STATE.SYN_ACK_SENT);
                    pending.setServerAckSeq(packet.getTcpAck());
                    pending.setServerSynSeq(packet.getTcpSeq());
                } else {
                    this.handshakes.remove(reverseFlow);
                }
            } else if (log.isDebugEnabled()) {
                log.debug("Cannot find handshake for SYN/ACK, maybe a retry?");
            }
            return true;
        }

        if (!server && ack) {
            // Client ACK completing the handshake.
            TcpHandshake pending = this.handshakes.get(packet.getFlow());
            if (pending != null
                && TcpHandshake.HANDSHAKE_STATE.SYN_ACK_SENT == pending.getState()
                && packet.getTcpAck() - 1L == pending.getServerSynSeq()) {
                pending.setAckTs(packet.getTsMilli());
                pending.setState(TcpHandshake.HANDSHAKE_STATE.ACK_RECV);
                pending.setClientAckSeq(packet.getTcpSeq());
                if (!packet.isTcpFlagPsh()) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Reads the TCP header length from the data offset nibble.
     *
     * @return header length in bytes, or {@code -1} when the data is too short to contain
     *         the data offset byte
     */
    private int getTcpHeaderLength(byte[] packet) {
        if (packet.length <= TCP_HEADER_DATA_OFFSET) {
            return -1;
        }
        // The upper nibble holds the header length in 32-bit words.
        return ((packet[TCP_HEADER_DATA_OFFSET] >> 4) & 0xF) * 4;
    }

    /**
     * Reads the length prefix of the next DNS message.
     *
     * @return the value read from the buffer, or {@code 0} (without reading anything) when
     *         {@code payload} is {@code null} or holds fewer bytes than the prefix
     */
    private int dnsMessageLen(ChainBuffer payload) {
        // Resolve the readable size up front so that a null payload cannot trigger an NPE
        // in the debug statement below.
        int readable = payload == null ? 0 : payload.readableBytes();
        if (readable < TCP_DNS_LENGTH_PREFIX) {
            if (log.isDebugEnabled()) {
                log.debug("Reading DNS message len from failed failed, only {} bytes remaining", readable);
            }
            return 0;
        }
        return payload.getShort();
    }

    /**
     * Decodes every complete DNS message in {@code buffer} into {@code packet}.
     *
     * <p>Decoding stops at the first message that is not completely available (or whose
     * length prefix is not positive); the buffer position is then moved back in front of
     * that length prefix so that the caller can keep the remainder.
     *
     * @param packet packet receiving the messages; must be a {@link DNSPacket}
     * @param buffer merged payload of the flow
     * @return the same {@code buffer}, positioned after the last decoded message
     */
    private ChainBuffer decodeDnsPayload(Packet packet, ChainBuffer buffer) {
        while (buffer.readableBytes() > TCP_DNS_LENGTH_PREFIX) {
            int len = this.dnsMessageLen(buffer);
            if (len > 0 && buffer.readableBytes() >= len) {
                byte[] data = null;
                int offset = 0;
                if (buffer.readableBytesCurrentBuffer() >= len) {
                    // Message lies completely in the current chain element: decode in place.
                    data = buffer.currentBuffer();
                    offset = buffer.getOffset();
                    buffer.position(buffer.position() + len);
                    // NOTE(review): only this in-place path passes len + prefix size to the
                    // decoder; the reason depends on ChainBuffer.getOffset() - confirm.
                    len += TCP_DNS_LENGTH_PREFIX;
                } else {
                    // Message spans several chain elements: copy it into the scratch buffer.
                    if (this.sharedDnsBuffer.length < len) {
                        this.sharedDnsBuffer = new byte[len];
                    }
                    buffer.gets(this.sharedDnsBuffer, 0, len);
                    data = this.sharedDnsBuffer;
                }
                this.dnsDecoder.decode((DNSPacket) packet, data, offset, len);
                continue;
            }
            // Incomplete message: un-read the length prefix and leave the rest for later.
            buffer.position(buffer.position() - TCP_DNS_LENGTH_PREFIX);
            break;
        }
        if (log.isDebugEnabled() && ((DNSPacket) packet).getMessageCount() > 1) {
            log.debug("multiple msg in TCP stream: {}", ((DNSPacket) packet).getMessageCount());
        }
        return buffer;
    }

    /**
     * Removes every flow that holds at least one payload older than {@code cacheTTL},
     * measured against the timestamp of the last processed packet, and logs cache
     * statistics.
     *
     * <p>NOTE(review): only {@code flows} is purged here; entries in {@code handshakes} are
     * removed solely by the handshake/reassembly logic.
     *
     * @param cacheTTL maximum payload age, in the unit of {@code Packet.getTsMilli()}
     */
    public void clearCache(int cacheTTL) {
        long oldestAllowed = this.lastPacketTs - (long) cacheTTL;
        List<TCPFlow> expired = new ArrayList<>();
        for (Map.Entry<TCPFlow, FlowData> entry : this.flows.entrySet()) {
            if (this.hasPayloadOlderThan(entry.getValue(), oldestAllowed)) {
                expired.add(entry.getKey());
            }
        }
        log.info("------------- TCP Decoder Cache Stats --------------------");
        log.info("TCP flow cache size: {}", this.flows.size());
        log.info("Expired (to be removed) TCP flows: {}", expired.size());
        // Removal happens after the iteration to avoid modifying the map while iterating it.
        for (TCPFlow flow : expired) {
            this.removeFlow(flow);
        }
    }

    /** Returns whether any payload of the flow was captured before {@code oldestAllowed}. */
    private boolean hasPayloadOlderThan(FlowData fd, long oldestAllowed) {
        for (SequencePayload payload : fd.getPayloads()) {
            if (payload.getTime() < oldestAllowed) {
                return true;
            }
        }
        return false;
    }

    /** Logs the packet and DNS message counters at info level. */
    @Override
    public void printStats() {
        log.info("---------------------- TCP Decoder Stats -----------------");
        log.info("packetCounter: {}", this.packetCounter);
        log.info("reqPacketCounter: {}", this.reqPacketCounter);
        log.info("rspPacketCounter: {}", this.rspPacketCounter);
        log.info("dnsRspMsgCounter: {}", this.dnsRspMsgCounter);
        log.info("dnsReqMsgCounter: {}", this.dnsReqMsgCounter);
    }

    /**
     * Resets all counters and the DNS decoder. Buffered flows and pending handshakes are
     * left untouched.
     */
    @Override
    public void reset() {
        this.packetCounter = 0;
        this.reqPacketCounter = 0;
        this.rspPacketCounter = 0;
        this.dnsRspMsgCounter = 0;
        this.dnsReqMsgCounter = 0;
        this.dnsDecoder.reset();
    }

    /** @return the decoder used for the reassembled DNS payload */
    public DNSDecoder getDnsDecoder() {
        return this.dnsDecoder;
    }

    /** @return the live (not copied) map of buffered flows */
    public Map<TCPFlow, FlowData> getFlows() {
        return this.flows;
    }

    /** @return the live (not copied) map of tracked handshakes */
    public Map<TCPFlow, TcpHandshake> getHandshakes() {
        return this.handshakes;
    }

    /** @return number of segments passed to {@link #reassemble(Packet, byte[])} */
    public int getPacketCounter() {
        return this.packetCounter;
    }

    /** @return number of reassembled request packets */
    public int getReqPacketCounter() {
        return this.reqPacketCounter;
    }

    /** @return number of reassembled response packets */
    public int getRspPacketCounter() {
        return this.rspPacketCounter;
    }

    /** @return number of DNS messages decoded from responses */
    public int getDnsRspMsgCounter() {
        return this.dnsRspMsgCounter;
    }

    /** @return number of DNS messages decoded from requests */
    public int getDnsReqMsgCounter() {
        return this.dnsReqMsgCounter;
    }

    /** @return the payload buffer of the most recently decoded segment */
    public ByteBuffer getPacketPayload() {
        return this.packetPayload;
    }

    /** @return the scratch buffer used for DNS messages spanning several chain elements */
    public byte[] getSharedDnsBuffer() {
        return this.sharedDnsBuffer;
    }

    /** @return timestamp (ms) of the most recently processed packet */
    public long getLastPacketTs() {
        return this.lastPacketTs;
    }

    public void setDnsDecoder(DNSDecoder dnsDecoder) {
        this.dnsDecoder = dnsDecoder;
    }

    public void setFlows(Map<TCPFlow, FlowData> flows) {
        this.flows = flows;
    }

    public void setHandshakes(Map<TCPFlow, TcpHandshake> handshakes) {
        this.handshakes = handshakes;
    }

    public void setPacketCounter(int packetCounter) {
        this.packetCounter = packetCounter;
    }

    public void setReqPacketCounter(int reqPacketCounter) {
        this.reqPacketCounter = reqPacketCounter;
    }

    public void setRspPacketCounter(int rspPacketCounter) {
        this.rspPacketCounter = rspPacketCounter;
    }

    public void setDnsRspMsgCounter(int dnsRspMsgCounter) {
        this.dnsRspMsgCounter = dnsRspMsgCounter;
    }

    public void setDnsReqMsgCounter(int dnsReqMsgCounter) {
        this.dnsReqMsgCounter = dnsReqMsgCounter;
    }

    public void setPacketPayload(ByteBuffer packetPayload) {
        this.packetPayload = packetPayload;
    }

    public void setSharedDnsBuffer(byte[] sharedDnsBuffer) {
        this.sharedDnsBuffer = sharedDnsBuffer;
    }

    public void setLastPacketTs(long lastPacketTs) {
        this.lastPacketTs = lastPacketTs;
    }
}

