package cool.scx.socket;

import cool.scx.util.ObjectUtils;
import io.netty.util.Timeout;
import io.vertx.core.Future;
import io.vertx.core.http.WebSocketBase;

import java.lang.System.Logger;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import static cool.scx.socket.ScxSocketFrame.fromJson;
import static cool.scx.socket.ScxSocketFrameType.*;
import static cool.scx.socket.ScxSocketHelper.setTimeout;
import static cool.scx.socket.SendOptions.DEFAULT_SEND_OPTIONS;
import static cool.scx.util.StringUtils.isBlank;
import static java.lang.System.Logger.Level.DEBUG;

/**
 * Wraps a Vert.x {@link WebSocketBase} and adds features on top of it:
 * <ul>
 *     <li>sequence-numbered message frames,</li>
 *     <li>optional acknowledgement of messages,</li>
 *     <li>duplicate suppression for acknowledged messages,</li>
 *     <li>named events,</li>
 *     <li>a ping/pong heartbeat with a timeout.</li>
 * </ul>
 *
 * <p>The wrapper can be re-attached to a new underlying socket via {@link #start(WebSocketBase)}.
 * {@link #close()} only cancels timers and tasks. It does not clear {@code sendTaskMap} or
 * {@code seqIDClearTaskMap}, so pending send tasks and duplicate-check entries are restarted on the
 * next {@code start}.
 */
public class ScxSocket {

    private static final ScxSocketFrame PING_FRAME = createPingFrame();
    private static final ScxSocketFrame PONG_FRAME = createPongFrame();
    protected final Logger logger = System.getLogger(this.getClass().getName());
    /** Outgoing send tasks, keyed by the {@code seq_id} of the frame each one carries. */
    final ConcurrentMap<Long, SendTask> sendTaskMap;
    /** Sequence IDs of received {@code MESSAGE_NEED_ACK} frames that were already delivered; used to drop resent copies. */
    final ConcurrentMap<Long, SeqIDClearTask> seqIDClearTaskMap;
    final ScxSocketOptions options;
    private final ConcurrentMap<String, Consumer<String>> eventHandlerMap;
    /** Source of sequence IDs for outgoing message/event frames. */
    private final AtomicLong seqID;
    protected WebSocketBase webSocket;
    /** Pending "send next ping" timer, or {@code null} when the heartbeat is stopped. */
    private Timeout ping;
    /** Pending "peer went silent" timer, or {@code null} when not armed. */
    private Timeout pingTimeout;
    private Consumer<String> onMessage;
    private Consumer<Void> onClose;
    private Consumer<Throwable> onError;

    /**
     * Creates an unattached socket; call {@link #start(WebSocketBase)} to attach it to a connection.
     *
     * @param options heartbeat and other socket options
     */
    public ScxSocket(ScxSocketOptions options) {
        this.options = options;
        this.seqID = new AtomicLong(0);
        this.sendTaskMap = new ConcurrentHashMap<>();
        this.eventHandlerMap = new ConcurrentHashMap<>();
        this.seqIDClearTaskMap = new ConcurrentHashMap<>();
    }

    /**
     * Builds an ACK frame whose payload is the acknowledged sequence ID as a decimal string.
     *
     * @param id the {@code seq_id} being acknowledged
     * @return a new ACK frame
     */
    private static ScxSocketFrame createAckFrame(long id) {
        var ackFrame = new ScxSocketFrame();
        ackFrame.seq_id = 0;
        ackFrame.type = ACK;
        ackFrame.now = 0;
        ackFrame.payload = Long.toString(id);
        return ackFrame;
    }

    /** Builds the shared PING frame (sequence ID 0, empty payload). */
    private static ScxSocketFrame createPingFrame() {
        var pingFrame = new ScxSocketFrame();
        pingFrame.seq_id = 0;
        pingFrame.type = PING;
        pingFrame.now = 0;
        pingFrame.payload = "";
        return pingFrame;
    }

    /** Builds the shared PONG frame (sequence ID 0, empty payload). */
    private static ScxSocketFrame createPongFrame() {
        var pongFrame = new ScxSocketFrame();
        pongFrame.seq_id = 0;
        pongFrame.type = PONG;
        pongFrame.now = 0;
        pongFrame.payload = "";
        return pongFrame;
    }

    /**
     * Builds a plain-text message frame with a freshly allocated sequence ID.
     *
     * @param content the message text
     * @param options decides between MESSAGE and MESSAGE_NEED_ACK
     * @return a new message frame
     */
    private ScxSocketFrame createMessageFrame(String content, SendOptions options) {
        var messageFrame = new ScxSocketFrame();
        messageFrame.seq_id = this.seqID.getAndIncrement();
        messageFrame.type = options.getNeedAck() ? MESSAGE_NEED_ACK : MESSAGE;
        messageFrame.now = System.currentTimeMillis();
        messageFrame.payload = content;
        return messageFrame;
    }

    /**
     * Builds an event frame: a message frame with an event name and a JSON payload.
     *
     * @param eventName name used by the receiver to dispatch to its event handler
     * @param data      object serialized to JSON via {@code ObjectUtils.toJson(data, "")}
     * @param options   decides between MESSAGE and MESSAGE_NEED_ACK
     * @return a new event frame
     */
    private ScxSocketFrame createEventFrame(String eventName, Object data, SendOptions options) {
        var eventFrame = new ScxSocketFrame();
        eventFrame.seq_id = this.seqID.getAndIncrement();
        eventFrame.type = options.getNeedAck() ? MESSAGE_NEED_ACK : MESSAGE;
        eventFrame.now = System.currentTimeMillis();
        eventFrame.event_name = eventName;
        eventFrame.payload = ObjectUtils.toJson(data, "");
        return eventFrame;
    }

    /**
     * Registers a {@link SendTask} for the frame under its sequence ID and starts it.
     * The actual write/retry behavior lives in {@code SendTask}.
     */
    private void send(ScxSocketFrame socketFrame, SendOptions options) {
        var sendTask = new SendTask(socketFrame, this, options);
        this.sendTaskMap.put(socketFrame.seq_id, sendTask);
        sendTask.start();
    }

    /**
     * Sends a text message using {@code DEFAULT_SEND_OPTIONS}.
     *
     * @param content the message text
     */
    public void send(String content) {
        send(createMessageFrame(content, DEFAULT_SEND_OPTIONS), DEFAULT_SEND_OPTIONS);
    }

    /**
     * Sends a text message.
     *
     * @param content the message text
     * @param options send options (e.g. whether the peer must acknowledge)
     */
    public void send(String content, SendOptions options) {
        send(createMessageFrame(content, options), options);
    }

    /**
     * Sends a named event whose payload is {@code data} serialized to JSON, using {@code DEFAULT_SEND_OPTIONS}.
     *
     * @param eventName the event name
     * @param data      the event payload
     */
    public void sendEvent(String eventName, Object data) {
        send(createEventFrame(eventName, data, DEFAULT_SEND_OPTIONS), DEFAULT_SEND_OPTIONS);
    }

    /**
     * Sends a named event whose payload is {@code data} serialized to JSON.
     *
     * @param eventName the event name
     * @param data      the event payload
     * @param options   send options (e.g. whether the peer must acknowledge)
     */
    public void sendEvent(String eventName, Object data, SendOptions options) {
        send(createEventFrame(eventName, data, options), options);
    }

    /** Writes an ACK for the given sequence ID directly to the socket (not tracked by a SendTask). */
    private void sendAck(long id) {
        this.webSocket.writeTextMessage(createAckFrame(id).toJson());
    }

    private Future<Void> sendPing() {
        return this.webSocket.writeTextMessage(PING_FRAME.toJson());
    }

    /** Writes a PONG frame directly to the socket. */
    protected void sendPong() {
        this.webSocket.writeTextMessage(PONG_FRAME.toJson());
    }

    private void startAllSendTask() {
        for (var value : sendTaskMap.values()) {
            value.start();
        }
    }

    private void cancelAllResendTask() {
        for (var value : sendTaskMap.values()) {
            value.cancelResend();
        }
    }

    private void startAllSendTaskAsync() {
        Thread.ofVirtual().start(this::startAllSendTask);
    }

    private void cancelAllResendTaskAsync() {
        Thread.ofVirtual().start(this::cancelAllResendTask);
    }

    /** Attaches text/close/exception handlers of {@code webSocket} to this instance. */
    private void bind(WebSocketBase webSocket) {
        this.webSocket = webSocket;
        this.webSocket.textMessageHandler(t -> doSocketFrame(fromJson(t)));
        this.webSocket.closeHandler(this::doClose);
        this.webSocket.exceptionHandler(this::doError);
    }

    /** Detaches the handlers installed by {@link #bind}, but only if the socket is still open. */
    private void removeBind() {
        if (this.webSocket != null && !this.webSocket.isClosed()) {
            this.webSocket.textMessageHandler(null);
            this.webSocket.closeHandler(null);
            this.webSocket.exceptionHandler(null);
        }
    }

    /**
     * Handles one decoded incoming frame and dispatches it by type.
     *
     * @param socketFrame the received frame
     */
    protected void doSocketFrame(ScxSocketFrame socketFrame) {
        // Any incoming frame proves the peer is alive, so restart both heartbeat timers.
        startPing();
        startPingTimeout();
        switch (socketFrame.type) {
            case MESSAGE -> doMessage(socketFrame);
            case MESSAGE_NEED_ACK -> doMessageNeedAck(socketFrame);
            case ACK -> doAck(socketFrame);
            case PING -> doPing(socketFrame);
            case PONG -> doPong(socketFrame);
        }
    }

    /** Routes a message frame to the message handler (no event name) or to the matching event handler. */
    private void doMessage(ScxSocketFrame socketFrame) {
        if (isBlank(socketFrame.event_name)) {
            callOnMessageWithCheckDuplicateAsync(socketFrame);
        } else {
            callOnEventWithCheckDuplicateAsync(socketFrame);
        }
    }

    /** Acknowledges the frame first, then processes it like a normal message (duplicates are ACKed but not re-delivered). */
    private void doMessageNeedAck(ScxSocketFrame socketFrame) {
        sendAck(socketFrame.seq_id);
        doMessage(socketFrame);
    }

    /**
     * Handles an ACK by clearing the send task registered for the acknowledged sequence ID, if any.
     *
     * @param ackFrame ACK frame whose payload is the acknowledged sequence ID
     */
    protected void doAck(ScxSocketFrame ackFrame) {
        var seqID = Long.parseLong(ackFrame.payload);
        var sendTask = sendTaskMap.get(seqID);
        if (sendTask != null) {
            sendTask.clear();
        }
    }

    private void doPing(ScxSocketFrame socketFrame) {
        sendPong();
        logger.log(DEBUG, "收到 ping");
    }

    private void doPong(ScxSocketFrame socketFrame) {
        // Nothing to do: receiving any frame already reset the heartbeat in doSocketFrame.
        logger.log(DEBUG, "收到 pong");
    }

    /**
     * Close handler of the underlying socket: tears down this wrapper, then notifies the onClose callback.
     *
     * @param v unused
     */
    protected void doClose(Void v) {
        this.close();
        // Notify the onClose callback.
        this.callOnClose(v);
    }

    /**
     * Exception handler of the underlying socket: tears down this wrapper, then notifies the onError callback.
     *
     * @param e the reported error
     */
    protected void doError(Throwable e) {
        this.close();
        // Notify the onError callback.
        this.callOnError(e);
    }

    private void closeWebSocket() {
        if (this.webSocket != null && !this.webSocket.isClosed()) {
            this.webSocket.close();
        }
    }

    /** (Re)arms the "peer silent" timer; it fires after pingTimeout + pingInterval without incoming frames. */
    private void startPingTimeout() {
        cancelPingTimeout();
        this.pingTimeout = setTimeout(this::doPingTimeout, options.getPingTimeout() + options.getPingInterval());
    }

    private void cancelPingTimeout() {
        if (this.pingTimeout != null) {
            this.pingTimeout.cancel();
            this.pingTimeout = null;
        }
    }

    /** (Re)arms the heartbeat: after pingInterval a PING is sent and the heartbeat re-arms itself. */
    protected void startPing() {
        cancelPing();
        this.ping = setTimeout(() -> {
            sendPing();
            startPing();
        }, options.getPingInterval());
    }

    private void cancelPing() {
        if (this.ping != null) {
            this.ping.cancel();
            this.ping = null;
        }
    }

    /**
     * Attaches this wrapper to a (new) underlying socket.
     *
     * <p>It first tears down any previous connection with {@link #close()}. It then binds handlers and restarts:
     * <ul>
     *     <li>the pending send tasks,</li>
     *     <li>the heartbeat,</li>
     *     <li>the ping timeout,</li>
     *     <li>the duplicate-check clear tasks.</li>
     * </ul>
     *
     * @param webSocket the socket to attach
     */
    void start(WebSocketBase webSocket) {
        close();
        // Bind handlers.
        this.bind(webSocket);
        // Restart all pending send tasks.
        this.startAllSendTask();
        // Start the heartbeat.
        this.startPing();
        // Arm the heartbeat timeout.
        this.startPingTimeout();
        // Restart the duplicate-check clear tasks.
        this.startAllClearTask();
    }

    /**
     * Detaches from and closes the underlying socket, and stops timers and tasks:
     * <ul>
     *     <li>resend of pending send tasks,</li>
     *     <li>ping and ping timeout,</li>
     *     <li>duplicate-check clear tasks.</li>
     * </ul>
     *
     * <p>Task maps are not cleared. The close handler is removed before the socket is closed, so this
     * method does not itself invoke the onClose callback.
     */
    public void close() {
        // Remove bound handlers.
        this.removeBind();
        // Close the connection.
        this.closeWebSocket();
        // Cancel all resend tasks.
        this.cancelAllResendTask();
        // Cancel the heartbeat.
        this.cancelPing();
        // Cancel the heartbeat timeout.
        this.cancelPingTimeout();
        // Cancel the duplicate-check clear tasks.
        this.cancelAllClearTask();
    }

    /**
     * Called when no frame arrived within pingTimeout + pingInterval; closes this socket.
     * NOTE(review): close() detaches the close handler first, so onClose is presumably not fired here — confirm intended.
     */
    protected void doPingTimeout() {
        this.close();
    }

    /**
     * @return {@code true} if no socket has been attached yet or the attached socket is closed
     */
    public boolean isClosed() {
        return webSocket == null || webSocket.isClosed();
    }

    /**
     * Sets the handler for plain (non-event) messages, replacing any previous one.
     * The handler is invoked on a new virtual thread.
     *
     * @param onMessage message handler
     * @return this, for chaining
     */
    public ScxSocket onMessage(Consumer<String> onMessage) {
        this.onMessage = onMessage;
        return this;
    }

    /**
     * Sets the handler invoked after the underlying socket reports closure, replacing any previous one.
     *
     * @param onClose close handler
     * @return this, for chaining
     */
    public ScxSocket onClose(Consumer<Void> onClose) {
        this.onClose = onClose;
        return this;
    }

    /**
     * Sets the handler invoked after the underlying socket reports an exception, replacing any previous one.
     *
     * @param onError error handler
     * @return this, for chaining
     */
    public ScxSocket onError(Consumer<Throwable> onError) {
        this.onError = onError;
        return this;
    }

    /**
     * Registers a handler for the named event, receiving the raw JSON payload; replaces any previous handler.
     * The handler is invoked on a new virtual thread.
     *
     * @param eventName event name
     * @param onEvent   handler receiving the raw payload string
     * @return this, for chaining
     */
    public ScxSocket onEvent(String eventName, Consumer<String> onEvent) {
        this.eventHandlerMap.put(eventName, onEvent);
        return this;
    }

    /**
     * Registers a handler for the named event whose payload is parsed via {@code ScxSocketHelper.fromJson}.
     * It replaces any previous handler.
     *
     * @param eventName event name
     * @param dataClass target type of the payload
     * @param onEvent   handler receiving the parsed payload
     * @param <T>       payload type
     * @return this, for chaining
     */
    public <T> ScxSocket onEvent(String eventName, Class<T> dataClass, Consumer<T> onEvent) {
        this.eventHandlerMap.put(eventName, (s) -> onEvent.accept(ScxSocketHelper.fromJson(s, dataClass)));
        return this;
    }

    // Handler invocation helpers. Each captures the (mutable) handler field into a local once,
    // so a concurrent re-registration between the null check and the call cannot cause an NPE.

    private void callOnMessage(String message) {
        var handler = this.onMessage;
        if (handler != null) {
            handler.accept(message);
        }
    }

    private void callOnMessage(ScxSocketFrame socketFrame) {
        var handler = this.onMessage;
        if (handler != null) {
            handler.accept(socketFrame.payload);
        }
    }

    private void callOnMessageWithCheckDuplicate(ScxSocketFrame socketFrame) {
        var handler = this.onMessage;
        if (handler != null && checkDuplicate(socketFrame)) {
            handler.accept(socketFrame.payload);
        }
    }

    private void callOnClose(Void v) {
        var handler = this.onClose;
        if (handler != null) {
            handler.accept(v);
        }
    }

    private void callOnError(Throwable e) {
        var handler = this.onError;
        if (handler != null) {
            handler.accept(e);
        }
    }

    private void callOnEvent(String eventName, String data) {
        var eventHandler = this.eventHandlerMap.get(eventName);
        if (eventHandler != null) {
            eventHandler.accept(data);
        }
    }

    private void callOnEvent(ScxSocketFrame socketFrame) {
        var eventHandler = this.eventHandlerMap.get(socketFrame.event_name);
        if (eventHandler != null) {
            eventHandler.accept(socketFrame.payload);
        }
    }

    private void callOnEventWithCheckDuplicate(ScxSocketFrame socketFrame) {
        var eventHandler = this.eventHandlerMap.get(socketFrame.event_name);
        if (eventHandler != null && checkDuplicate(socketFrame)) {
            eventHandler.accept(socketFrame.payload);
        }
    }

    private void callOnMessageAsync(String message) {
        var handler = this.onMessage;
        if (handler != null) {
            Thread.ofVirtual().start(() -> handler.accept(message));
        }
    }

    private void callOnMessageAsync(ScxSocketFrame socketFrame) {
        var handler = this.onMessage;
        if (handler != null) {
            Thread.ofVirtual().start(() -> handler.accept(socketFrame.payload));
        }
    }

    private void callOnMessageWithCheckDuplicateAsync(ScxSocketFrame socketFrame) {
        var handler = this.onMessage;
        if (handler != null && checkDuplicate(socketFrame)) {
            Thread.ofVirtual().start(() -> handler.accept(socketFrame.payload));
        }
    }

    private void callOnCloseAsync(Void v) {
        var handler = this.onClose;
        if (handler != null) {
            Thread.ofVirtual().start(() -> handler.accept(v));
        }
    }

    private void callOnErrorAsync(Throwable e) {
        var handler = this.onError;
        if (handler != null) {
            Thread.ofVirtual().start(() -> handler.accept(e));
        }
    }

    private void callOnEventAsync(String eventName, String data) {
        var eventHandler = this.eventHandlerMap.get(eventName);
        if (eventHandler != null) {
            Thread.ofVirtual().start(() -> eventHandler.accept(data));
        }
    }

    private void callOnEventAsync(ScxSocketFrame socketFrame) {
        var eventHandler = this.eventHandlerMap.get(socketFrame.event_name);
        if (eventHandler != null) {
            Thread.ofVirtual().start(() -> eventHandler.accept(socketFrame.payload));
        }
    }

    private void callOnEventWithCheckDuplicateAsync(ScxSocketFrame socketFrame) {
        var eventHandler = this.eventHandlerMap.get(socketFrame.event_name);
        if (eventHandler != null && checkDuplicate(socketFrame)) {
            Thread.ofVirtual().start(() -> eventHandler.accept(socketFrame.payload));
        }
    }

    /**
     * Decides whether a received frame should be delivered to handlers.
     *
     * <p>Only {@code MESSAGE_NEED_ACK} frames can be resent by the peer, so only they are checked.
     * The first time such a sequence ID is seen, a {@link SeqIDClearTask} is recorded for it and started.
     * Later frames with the same ID are treated as resends.
     *
     * @param socketFrame the received frame
     * @return {@code true} if the frame is NOT a resend and should be delivered; {@code false} if it is a duplicate
     */
    private boolean checkDuplicate(ScxSocketFrame socketFrame) {
        if (socketFrame.type != MESSAGE_NEED_ACK) {
            return true;
        }
        var seqID = socketFrame.seq_id;
        // computeIfAbsent makes "check whether seen, otherwise record" atomic on the concurrent map.
        var created = new boolean[]{false};
        var task = seqIDClearTaskMap.computeIfAbsent(seqID, id -> {
            created[0] = true;
            return new SeqIDClearTask(id, this);
        });
        if (!created[0]) {
            return false;
        }
        // Start outside the mapping function so the task may safely touch the map itself.
        task.start();
        return true;
    }

    private void startAllClearTask() {
        for (var value : seqIDClearTaskMap.values()) {
            value.start();
        }
    }

    private void cancelAllClearTask() {
        for (var value : seqIDClearTaskMap.values()) {
            value.cancel();
        }
    }

    private void startAllClearTaskAsync() {
        Thread.ofVirtual().start(this::startAllClearTask);
    }

    private void cancelAllClearTaskAsync() {
        // Previously referenced itself, spawning threads forever without cancelling anything.
        Thread.ofVirtual().start(this::cancelAllClearTask);
    }

}
