/*
 * Decompiled with CFR 0.152.
 */
package org.teamapps.universaldb.cluster.network;

import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.teamapps.universaldb.cluster.message.ClusterMessage;
import org.teamapps.universaldb.cluster.network.ConnectionHandler;
import org.teamapps.universaldb.cluster.network.MessageType;
import org.teamapps.universaldb.cluster.network.NetworkWriter;

/**
 * Encrypted, bidirectional TCP connection to another cluster node.
 *
 * <p>Every packet on the wire is a 1-byte message id, a 4-byte big-endian payload length and the
 * payload. The first exchange is an unencrypted handshake:
 * <ul>
 *   <li>The connecting side ({@link #connect}) sends {@code ENCRYPTION_INIT} with two 8-byte
 *       random values.</li>
 *   <li>The accepting side ({@link #setSocket}) answers with {@code ENCRYPTION_INIT_RESPONSE},
 *       carrying two more random values.</li>
 *   <li>Both sides derive the AES key as {@code MD5(random2 || secret || random4)} and the IV as
 *       {@code random1 || random3}.</li>
 *   <li>All later payloads are AES/CBC encrypted.</li>
 * </ul>
 *
 * <p>Incoming packets are processed on a dedicated reader thread. Outgoing {@link ClusterMessage}s
 * are queued and written by a dedicated writer thread. The first failure closes the socket, stops
 * both threads and notifies the {@link ConnectionHandler}; later failures are ignored.
 *
 * <p>NOTE(review): the protocol has weaknesses that cannot be fixed here without changing the
 * wire format for every node:
 * <ul>
 *   <li>The key comes from a single MD5 digest.</li>
 *   <li>Payloads carry no MAC.</li>
 *   <li>{@link Cipher#doFinal(byte[])} resets the cipher to its initialized state, so every
 *       message of a connection is encrypted with the same key and IV.</li>
 * </ul>
 */
public class NodeConnection implements NetworkWriter {
    private static final Logger log = LoggerFactory.getLogger(NodeConnection.class);
    /** Packet header size: 1 byte message id + 4 byte payload length. */
    private static final int HEADER_SIZE = 5;
    /** Largest payload length accepted from the peer. */
    private static final int MAX_PAYLOAD_SIZE = 1000000000;
    /** Size in bytes of each handshake random value. */
    private static final int RANDOM_SIZE = 8;
    /** Minimum payload size of a handshake message (two random values). */
    private static final int HANDSHAKE_SIZE = 2 * RANDOM_SIZE;
    /** Maximum number of outgoing messages waiting for the writer thread. */
    private static final int QUEUE_CAPACITY = 100000;
    private static final String CIPHER_TRANSFORMATION = "AES/CBC/PKCS5Padding";

    private String clusterSecret;
    private ConnectionHandler connectionHandler;
    private Socket socket;
    /** Flipped by the reader thread after the handshake and read by the writer thread. */
    private volatile boolean initialized;
    private Cipher encryptionCipher;
    private Cipher decryptionCipher;
    private final byte[] random1 = new byte[RANDOM_SIZE];
    private final byte[] random2 = new byte[RANDOM_SIZE];
    private final byte[] random3 = new byte[RANDOM_SIZE];
    private final byte[] random4 = new byte[RANDOM_SIZE];
    private volatile boolean running = true;
    /** Ensures cleanup and handler notification in {@link #connectionError} run only once. */
    private final AtomicBoolean closed = new AtomicBoolean();
    /**
     * Serializes socket writes and use of the stateful encryption cipher. Both the reader thread
     * (handshake response) and the writer thread write to the socket.
     */
    private final Object writeLock = new Object();
    /** Interrupted by {@link #connectionError} so it does not stay blocked in {@code take()}. */
    private volatile Thread writerThread;
    private ClusterMessage initialMessage;
    private final ArrayBlockingQueue<ClusterMessage> messageQueue;

    /**
     * Creates an unconnected node connection.
     *
     * @param clusterSecret     shared secret used for key derivation; cleared once the ciphers exist
     * @param connectionHandler receives connection, message and error callbacks
     */
    public NodeConnection(String clusterSecret, ConnectionHandler connectionHandler) {
        this.clusterSecret = clusterSecret;
        this.connectionHandler = connectionHandler;
        this.messageQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY);
    }

    /**
     * Uses an already accepted socket.
     *
     * <p>This side acts as the accepting side: it waits for the peer's {@code ENCRYPTION_INIT}.
     *
     * @param socket the connected socket
     */
    public void setSocket(Socket socket) {
        try {
            this.socket = socket;
            configureSocket(socket);
            this.startReaderThread();
            this.startWriterThread();
        } catch (SocketException e) {
            this.connectionError("Error setting socket options:" + e.getMessage());
        }
    }

    /**
     * Connects to a remote node and starts the handshake by sending {@code ENCRYPTION_INIT}.
     *
     * @param address        host to connect to
     * @param port           port to connect to
     * @param initialMessage optional message queued right after the handshake; may be {@code null}
     */
    public void connect(String address, int port, ClusterMessage initialMessage) {
        try {
            this.initialMessage = initialMessage;
            Socket socket = new Socket(address, port);
            this.socket = socket;
            configureSocket(socket);
            SecureRandom secureRandom = new SecureRandom();
            secureRandom.nextBytes(this.random1);
            secureRandom.nextBytes(this.random2);
            this.sendMessage(MessageType.ENCRYPTION_INIT, this.combineByteArrays(this.random1, this.random2));
            if (!this.running) {
                // The handshake could not be sent; connectionError has already cleaned up.
                return;
            }
            this.startReaderThread();
            this.startWriterThread();
        } catch (IOException e) {
            this.connectionError("Error connecting to " + address + ", " + e.getMessage());
        }
    }

    /**
     * Closes the connection and notifies the handler, unless it was already closed.
     *
     * @param reason text logged as the cause
     */
    @Override
    public void closeConnection(String reason) {
        this.connectionError(reason);
    }

    /**
     * Queues a message for the writer thread.
     *
     * <p>If the queue is full, the connection is closed.
     *
     * @param clusterMessage the message to send
     */
    @Override
    public void sendMessage(ClusterMessage clusterMessage) {
        if (!this.messageQueue.offer(clusterMessage)) {
            this.connectionError("Full queue - need to check cluster node state");
        }
    }

    /**
     * Frames and writes one packet.
     *
     * <p>The payload is encrypted only once the handshake has completed. Any failure closes the
     * connection.
     */
    private void sendMessage(MessageType messageType, byte[] data) {
        try {
            synchronized (this.writeLock) {
                byte[] payload = this.initialized ? this.encryptionCipher.doFinal(data) : data;
                ByteBuffer packet = ByteBuffer.allocate(HEADER_SIZE + payload.length);
                packet.put((byte) messageType.getMessageId());
                packet.putInt(payload.length);
                packet.put(payload);
                OutputStream out = this.socket.getOutputStream();
                out.write(packet.array());
                out.flush();
            }
        } catch (Exception e) {
            // Handled outside the lock so that closing never waits on a blocked write.
            this.connectionError(e.getMessage());
        }
    }

    @Override
    public void setConnectionHandler(ConnectionHandler connectionHandler) {
        this.connectionHandler = connectionHandler;
    }

    /** Applies the socket options used by both connection directions. */
    private static void configureSocket(Socket socket) throws SocketException {
        socket.setKeepAlive(true);
        socket.setTcpNoDelay(true);
    }

    /**
     * Derives key and IV from the handshake values and the cluster secret, then marks the
     * connection as initialized.
     *
     * <p>On failure the connection is closed and {@code initialized} stays false.
     */
    private void createCiphers() {
        try {
            MessageDigest md = MessageDigest.getInstance("MD5");
            byte[] ivBytes = this.combineByteArrays(this.random1, this.random3);
            byte[] keyInput = this.combineByteArrays(this.random2, this.clusterSecret.getBytes(StandardCharsets.UTF_8), this.random4);
            byte[] keyBytes = md.digest(keyInput);
            this.clusterSecret = null;
            SecretKeySpec key = new SecretKeySpec(keyBytes, "AES");
            IvParameterSpec iv = new IvParameterSpec(ivBytes);
            Cipher encryption = Cipher.getInstance(CIPHER_TRANSFORMATION);
            encryption.init(Cipher.ENCRYPT_MODE, key, iv);
            Cipher decryption = Cipher.getInstance(CIPHER_TRANSFORMATION);
            decryption.init(Cipher.DECRYPT_MODE, key, iv);
            this.encryptionCipher = encryption;
            this.decryptionCipher = decryption;
            // Volatile write last: it publishes both ciphers to the writer thread.
            this.initialized = true;
        } catch (Exception e) {
            this.connectionError(e.getMessage());
        }
    }

    /** Concatenates the given arrays in order into a new array. */
    private byte[] combineByteArrays(byte[]... byteArrays) {
        int length = 0;
        for (byte[] array : byteArrays) {
            length += array.length;
        }
        byte[] result = new byte[length];
        int offset = 0;
        for (byte[] array : byteArrays) {
            System.arraycopy(array, 0, result, offset, array.length);
            offset += array.length;
        }
        return result;
    }

    /** Returns a thread name unique to this connection's remote endpoint. */
    private String threadName(String prefix) {
        return prefix + this.socket.getInetAddress().getHostAddress() + "-" + this.socket.getPort();
    }

    private void startReaderThread() {
        new Thread(this::readLoop, this.threadName("node-connection-reader-")).start();
    }

    /**
     * Reads framed packets until the connection stops.
     *
     * <p>Each packet is decrypted if the handshake has completed, then dispatched. Any read,
     * decryption or processing error closes the connection.
     */
    private void readLoop() {
        try {
            DataInputStream input = new DataInputStream(this.socket.getInputStream());
            byte[] header = new byte[HEADER_SIZE];
            while (this.running) {
                // readFully blocks until the buffer is complete and throws EOFException on end of stream.
                input.readFully(header);
                ByteBuffer buffer = ByteBuffer.wrap(header);
                byte messageId = buffer.get();
                int len = buffer.getInt();
                MessageType messageType = MessageType.getById(messageId);
                if (messageType == null) {
                    this.connectionError("Unknown message type:" + messageId);
                    return;
                }
                if (len <= 0 || len > MAX_PAYLOAD_SIZE) {
                    this.connectionError("Invalid message size:" + len);
                    return;
                }
                byte[] data = new byte[len];
                input.readFully(data);
                byte[] payload = this.initialized ? this.decryptionCipher.doFinal(data) : data;
                this.handleMessage(messageType, payload);
            }
        } catch (IOException e) {
            this.connectionError("Error reading data:" + e.getMessage());
        } catch (BadPaddingException e) {
            this.connectionError("Padding error:" + e.getMessage());
        } catch (IllegalBlockSizeException e) {
            this.connectionError("Block size error:" + e.getMessage());
        } catch (RuntimeException e) {
            // Without this the reader would die silently and leave a half-open connection.
            log.warn("Unexpected error while processing incoming data", e);
            this.connectionError("Error processing message:" + e);
        }
    }

    private void startWriterThread() {
        Thread thread = new Thread(this::writeLoop, this.threadName("node-connection-writer-"));
        // Publish before starting so connectionError can always find (and interrupt) it.
        this.writerThread = thread;
        thread.start();
    }

    /** Takes queued messages and writes them until the connection stops or the thread is interrupted. */
    private void writeLoop() {
        while (this.running) {
            ClusterMessage clusterMessage;
            try {
                clusterMessage = this.messageQueue.take();
            } catch (InterruptedException e) {
                // Interrupted by connectionError: restore the flag and terminate.
                Thread.currentThread().interrupt();
                return;
            }
            this.sendMessage(clusterMessage.getType(), clusterMessage.getData());
        }
    }

    /**
     * Stops the connection.
     *
     * <p>Only the first call has any effect: it logs the reason, stops the writer thread, closes
     * the socket and notifies the handler.
     */
    private void connectionError(String msg) {
        this.running = false;
        if (!this.closed.compareAndSet(false, true)) {
            return;
        }
        log.info("Connection error:{}", msg);
        Thread writer = this.writerThread;
        if (writer != null) {
            writer.interrupt();
        }
        try {
            if (this.socket != null) {
                this.socket.close();
            }
        } catch (IOException e) {
            log.warn("Error closing socket", e);
        }
        this.connectionHandler.handleConnectionError();
    }

    /**
     * Dispatches an incoming (already decrypted) payload.
     *
     * <p>Before the handshake completes, only handshake messages are processed. Anything else is
     * ignored, as before.
     */
    private void handleMessage(MessageType messageType, byte[] data) {
        if (this.initialized) {
            this.connectionHandler.handleMessage(messageType, data, this);
        } else if (messageType == MessageType.ENCRYPTION_INIT) {
            this.handleEncryptionInit(data);
        } else if (messageType == MessageType.ENCRYPTION_INIT_RESPONSE) {
            this.handleEncryptionInitResponse(data);
        }
    }

    /** Accepting side: stores the peer's randoms, replies with our own and enables encryption. */
    private void handleEncryptionInit(byte[] data) {
        if (!this.checkHandshakeSize(data)) {
            return;
        }
        ByteBuffer buffer = ByteBuffer.wrap(data);
        buffer.get(this.random1);
        buffer.get(this.random2);
        SecureRandom secureRandom = new SecureRandom();
        secureRandom.nextBytes(this.random3);
        secureRandom.nextBytes(this.random4);
        // Sent before createCiphers() sets 'initialized', so the response goes out unencrypted.
        this.sendMessage(MessageType.ENCRYPTION_INIT_RESPONSE, this.combineByteArrays(this.random3, this.random4));
        if (!this.running) {
            return;
        }
        this.createCiphers();
        if (!this.initialized) {
            return;
        }
        this.connectionHandler.handleConnected(this);
    }

    /** Connecting side: stores the peer's randoms, enables encryption and queues the initial message. */
    private void handleEncryptionInitResponse(byte[] data) {
        if (!this.checkHandshakeSize(data)) {
            return;
        }
        ByteBuffer buffer = ByteBuffer.wrap(data);
        buffer.get(this.random3);
        buffer.get(this.random4);
        this.createCiphers();
        if (!this.initialized) {
            return;
        }
        this.connectionHandler.handleConnected(this);
        if (this.initialMessage != null) {
            this.sendMessage(this.initialMessage);
        }
    }

    /** Closes the connection and returns false if a handshake payload is too short to hold two randoms. */
    private boolean checkHandshakeSize(byte[] data) {
        if (data.length < HANDSHAKE_SIZE) {
            this.connectionError("Invalid handshake message size:" + data.length);
            return false;
        }
        return true;
    }
}

