/*
 * Copyright (c) 2012-2018 by Zalo Group.
 * All Rights Reserved.
 */
package com.zing.zalo.zbrowser.downloader;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.concurrent.ExecutorService;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;

/**
 * Non-blocking TLS client connection that drives an {@link SSLEngine} over a
 * {@link SocketChannel}.
 * <p>
 * The owner calls {@link #checkForData()} on socket events; each call flushes pending
 * ciphertext, reads whatever the socket has available, and advances the handshake /
 * decrypts application data as far as possible. Outcomes are reported via {@link Callback}.
 * <p>
 * Buffer roles (all kept in "fill" mode, i.e. position = end of pending data, between calls):
 * <ul>
 * <li>{@code clientWrap} - plaintext queued by {@link #send(ByteBuffer)}, waiting to be encrypted</li>
 * <li>{@code serverWrap} - ciphertext produced by the engine, waiting to be written to the socket</li>
 * <li>{@code clientUnwrap} - ciphertext read from the socket, waiting to be decrypted</li>
 * <li>{@code serverUnwrap} - plaintext produced by the engine, handed to the callback</li>
 * </ul>
 *
 * @author datbt
 */
public class ZSSLConnection {

	/** No-op callback used until {@link #setCallback(Callback)} installs a real one. */
	private static final Callback DEFAULT_CALLBACK = new Callback() {
		@Override
		public void onHandshakeSuccess() {
		}

		@Override
		public void onDataResponse(byte[] data) {
		}

		@Override
		public void onConnectionClosed() {
		}
	};

	private final SSLEngine engine;
	/** Executor for delegated SSL tasks; {@code null} means run them on the calling thread. */
	private final ExecutorService delegatedTaskWorkers;
	private final ByteBuffer clientWrap, clientUnwrap;
	private final ByteBuffer serverWrap, serverUnwrap;

	/** Staging buffer for raw socket reads before they are appended to {@code clientUnwrap}. */
	private final ByteBuffer internalBuffer;

	private final SocketChannel socket;

	private boolean isHandshakeSuccess = false;

	private boolean isConnectionReuse = false;

	/** Ensures {@link Callback#onConnectionClosed()} is delivered at most once. */
	private boolean isClosedNotified = false;

	private Callback sslCallback;

	/**
	 * Creates a client-mode engine for {@code url} and starts its handshake. No I/O is
	 * performed here; the handshake is driven by {@link #checkForData()}.
	 *
	 * @param socket               connected channel used for all reads and writes
	 * @param sslContext           context the engine is created from
	 * @param delegatedTaskWorkers executor for delegated SSL tasks, or {@code null} to run them inline
	 * @param internalBufferSize   size of the staging buffer used for socket reads
	 * @param url                  target; its domain and port are passed to the engine as peer hints
	 * @throws SSLException if the handshake cannot be started
	 */
	public ZSSLConnection(SocketChannel socket, SSLContext sslContext, ExecutorService delegatedTaskWorkers, int internalBufferSize, HttpUrl url) throws SSLException {
		// NOTE(review): this class does not set an endpoint identification algorithm
		// (SSLParameters#setEndpointIdentificationAlgorithm("HTTPS")), so whether the server
		// certificate is checked against url.domain depends entirely on the supplied
		// SSLContext -- confirm this is intended.
		SSLEngine sslEngine = sslContext.createSSLEngine(url.domain, url.port);
		sslEngine.setUseClientMode(true);
		sslEngine.beginHandshake();

		this.engine = sslEngine;
		this.socket = socket;
		this.delegatedTaskWorkers = delegatedTaskWorkers;

		this.internalBuffer = ByteBuffer.allocate(internalBufferSize);

		int netCapacity = engine.getSession().getPacketBufferSize();
		int appCapacity = Math.max(engine.getSession().getApplicationBufferSize(), internalBufferSize);

		this.clientWrap = ByteBuffer.allocate(netCapacity);
		this.serverWrap = ByteBuffer.allocate(netCapacity);

		// clientUnwrap holds raw ciphertext, so it must fit at least one full TLS record
		// (packet buffer size); otherwise unwrap() would report BUFFER_UNDERFLOW forever.
		this.clientUnwrap = ByteBuffer.allocate(Math.max(netCapacity, appCapacity));
		this.serverUnwrap = ByteBuffer.allocate(appCapacity);

		this.sslCallback = DEFAULT_CALLBACK;
	}

	/**
	 * Installs the event callback.
	 *
	 * @param callback receiver of connection events; {@code null} restores the no-op callback
	 */
	public void setCallback(Callback callback) {
		if (callback == null) {
			callback = DEFAULT_CALLBACK;
		}
		sslCallback = callback;
	}

	/** @return {@code true} once the engine has reported a finished handshake */
	public boolean isHandshakeSuccess() {
		return isHandshakeSuccess;
	}

	/**
	 * Prepares the connection for another request: discards plaintext queued by
	 * {@link #send(ByteBuffer)} that has not been encrypted yet, and marks the connection
	 * as reused so that the next idle {@link #checkForData()} call re-reports
	 * {@link Callback#onHandshakeSuccess()} (if the handshake already succeeded).
	 */
	public void reuseConnection() {
		this.clientWrap.clear();
		this.isConnectionReuse = true;
	}

	/** @return {@code true} while a reuse requested via {@link #reuseConnection()} is pending */
	public boolean isConnectionReuse() {
		return isConnectionReuse;
	}

	/**
	 * Drives the connection one step: flushes pending ciphertext, reads available bytes
	 * from the socket, then runs the handshake / data state machine until it cannot make
	 * further progress.
	 * <p>
	 * If previously produced ciphertext still cannot be fully written, this call only
	 * retries the write and returns. On end-of-stream, data already buffered is processed
	 * first and then {@link Callback#onConnectionClosed()} is invoked (at most once).
	 *
	 * @throws IOException if socket I/O or the SSL engine fails
	 */
	public void checkForData() throws IOException {
		if (!internalWrite()) {
			// Socket cannot take more right now -- wait for the next event.
			return;
		}
		int bytesRead = internalRead();
		doHandshake();
		if (bytesRead == -1) {
			notifyConnectionClosed();
		}
	}

	/**
	 * Queues plaintext for encryption; it is wrapped and written during a later
	 * {@link #checkForData()} call.
	 *
	 * @param data plaintext whose remaining bytes are copied
	 * @throws java.nio.BufferOverflowException if the data does not fit in the pending-send buffer
	 */
	public void send(ByteBuffer data) {
		clientWrap.put(data);
	}

	/**
	 * Reads ciphertext from the socket into {@code clientUnwrap}, never requesting more than
	 * {@code clientUnwrap} can hold or {@code internalBuffer} can stage.
	 *
	 * @return bytes read, 0 if nothing was read, or -1 on end-of-stream
	 */
	private int internalRead() throws IOException {
		internalBuffer.clear();
		// Bound by both buffers: a limit above internalBuffer's capacity would throw.
		internalBuffer.limit(Math.min(internalBuffer.capacity(), clientUnwrap.remaining()));

		int bytes = socket.read(internalBuffer);
		if (bytes <= 0) {
			return bytes;
		}

		internalBuffer.flip();
		clientUnwrap.put(internalBuffer);
		return bytes;
	}

	/**
	 * Writes pending ciphertext from {@code serverWrap} to the socket. The buffer is always
	 * returned to fill mode with any unsent bytes kept at its front, so later wraps append
	 * after them.
	 *
	 * @return {@code true} if no ciphertext remains pending
	 */
	private boolean internalWrite() throws IOException {
		if (serverWrap.position() == 0) {
			return true;
		}
		serverWrap.flip();
		try {
			socket.write(serverWrap);
		} finally {
			serverWrap.compact();
		}
		return serverWrap.position() == 0;
	}

	/** Repeats state-machine steps until one reports that no further progress is possible. */
	private void doHandshake() throws IOException {
		while (isHandshaking()) {
			// each iteration performs one wrap / unwrap / delegated-task step
		}
	}

	/**
	 * Performs one step of the handshake / data state machine.
	 *
	 * @return {@code true} if the step made progress and another step should be attempted
	 */
	private synchronized boolean isHandshaking() throws IOException {
		SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();

		switch (handshakeStatus) {
			case NOT_HANDSHAKING: {
				boolean occupied = false;
				if (clientWrap.position() > 0) {
					occupied |= this.wrap();
				}
				if (clientUnwrap.position() > 0) {
					occupied |= this.unwrap();
				}

				if (!occupied && isHandshakeSuccess && isConnectionReuse) {
					// Connection is reused -- no need to wait for a new handshake.
					isConnectionReuse = false;
					sslCallback.onHandshakeSuccess();
				}
				return occupied;
			}

			case NEED_WRAP:
				return this.wrap();

			case NEED_UNWRAP:
				return this.unwrap();

			case NEED_TASK:
				return runDelegatedTasks();

			case FINISHED:
				// SSLEngine.getHandshakeStatus() is not expected to report FINISHED.
				throw new IllegalStateException("FINISHED");

			default:
				return true;
		}
	}

	/**
	 * Runs (inline) or dispatches (to {@code delegatedTaskWorkers}) every outstanding
	 * delegated task.
	 *
	 * @return {@code true} if tasks ran inline, so the handshake can continue immediately;
	 *         {@code false} if they were dispatched (the owner must call
	 *         {@link #checkForData()} again later) or there was nothing to run
	 */
	private boolean runDelegatedTasks() {
		boolean runInline = delegatedTaskWorkers == null;
		boolean ranAny = false;
		Runnable sslTask;
		while ((sslTask = engine.getDelegatedTask()) != null) {
			if (runInline) {
				sslTask.run();
			} else {
				delegatedTaskWorkers.execute(sslTask);
			}
			ranAny = true;
		}
		return runInline && ranAny;
	}

	/**
	 * Encrypts queued plaintext (or produces handshake data) into {@code serverWrap} and
	 * writes it to the socket.
	 *
	 * @return {@code true} if another step should be attempted
	 */
	private boolean wrap() throws IOException {
		// Do not produce more ciphertext while earlier output is still unsent.
		if (!internalWrite()) {
			return false;
		}

		clientWrap.flip();
		SSLEngineResult wrapResult = engine.wrap(clientWrap, serverWrap);
		clientWrap.compact();

		boolean flushed = true;
		switch (wrapResult.getStatus()) {
			case OK:
				flushed = internalWrite();
				break;

			case BUFFER_UNDERFLOW:
				// try again later
				break;

			case BUFFER_OVERFLOW:
				throw new IllegalStateException("failed to wrap");

			case CLOSED:
				notifyConnectionClosed();
				return false;
		}

		if (wrapResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
			markHandshakeSuccess();
			return false;
		}

		// A partial write stops the loop; the next checkForData() finishes it.
		return flushed;
	}

	/**
	 * Decrypts buffered ciphertext from {@code clientUnwrap}, handing any plaintext to
	 * {@link Callback#onDataResponse(byte[])}.
	 *
	 * @return {@code true} if another step should be attempted
	 */
	private boolean unwrap() throws SSLException {
		clientUnwrap.flip();
		SSLEngineResult unwrapResult = engine.unwrap(clientUnwrap, serverUnwrap);
		clientUnwrap.compact();

		switch (unwrapResult.getStatus()) {
			case OK:
				if (serverUnwrap.position() > 0) {
					serverUnwrap.flip();

					byte[] dataUnwrap = new byte[serverUnwrap.remaining()];
					serverUnwrap.get(dataUnwrap);
					sslCallback.onDataResponse(dataUnwrap);

					serverUnwrap.compact();
				} else if (unwrapResult.bytesConsumed() == 0
						&& unwrapResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
					// No progress: wait for more bytes. When a record WAS consumed, keep going,
					// because further records may already be buffered.
					return false;
				}
				break;

			case CLOSED:
				notifyConnectionClosed();
				return false;

			case BUFFER_OVERFLOW:
				throw new IllegalStateException("failed to unwrap");

			case BUFFER_UNDERFLOW:
				// Incomplete TLS record -- wait for more bytes from the socket.
				return false;
		}

		if (unwrapResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
			markHandshakeSuccess();
			return false;
		}

		return true;
	}

	/** Records a completed handshake and notifies the callback the first time only. */
	private void markHandshakeSuccess() {
		if (!isHandshakeSuccess) {
			isHandshakeSuccess = true;
			sslCallback.onHandshakeSuccess();
		}
	}

	/** Notifies the callback that the connection closed, at most once per connection. */
	private void notifyConnectionClosed() {
		if (!isClosedNotified) {
			isClosedNotified = true;
			sslCallback.onConnectionClosed();
		}
	}

	/** Receiver of connection events; invoked from within {@link #checkForData()}. */
	public interface Callback {
		/** Called once the TLS handshake completes, and again when a reused connection becomes idle. */
		void onHandshakeSuccess();

		/** Called with each chunk of decrypted application data. */
		void onDataResponse(byte[] data);

		/** Called when the engine reports CLOSED or the socket reaches end-of-stream. */
		void onConnectionClosed();
	}

}
