package kz.greetgo.spring.websocket.beans;

import com.fasterxml.jackson.databind.ObjectMapper;
import kz.greetgo.spring.websocket.controller.AnnotationFinder;
import kz.greetgo.spring.websocket.controller.ControllerManager;
import kz.greetgo.spring.websocket.controller.ExecuteInput;
import kz.greetgo.spring.websocket.controller.PreExecuteInterceptor;
import kz.greetgo.spring.websocket.interfaces.MessageSender;
import kz.greetgo.spring.websocket.interfaces.NeedClose;
import kz.greetgo.spring.websocket.interfaces.WebsocketController;
import kz.greetgo.spring.websocket.model.ToClient;
import kz.greetgo.spring.websocket.model.ToServer;
import kz.greetgo.spring.websocket.util.ConsoleColors;
import kz.greetgo.spring.websocket.util.LoggingUtil;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Base WebSocket handler that keeps a registry of open sessions, decodes incoming text
 * messages as {@link ToServer} JSON and dispatches them to controllers through a
 * {@link ControllerManager}.
 *
 * <p>Each registered session carries a token string and an optional user data object.
 * The user data object is closed together with the session.
 *
 * @param <UserSessionData> type of the per-session user data; closed when the session is closed
 */
public abstract class AbstractWebSocketHandlerBean<UserSessionData extends NeedClose>
  extends AbstractWebSocketHandler implements InitializingBean, PreExecuteInterceptor {

  /**
   * Everything this handler stores for one registered session.
   *
   * @param <UserSessionData> type of the per-session user data
   */
  @RequiredArgsConstructor
  private static class SessionData<UserSessionData extends NeedClose> {
    final WebSocketSession webSocketSession;
    final AtomicReference<String> token = new AtomicReference<>();
    final AtomicReference<UserSessionData> userData = new AtomicReference<>();
  }

  /**
   * Shared JSON mapper. It is never reconfigured after creation, so a single instance is
   * reused for all messages instead of building a new one per message.
   */
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  /** Registered sessions keyed by WebSocket session id. */
  private final ConcurrentHashMap<String, SessionData<UserSessionData>> sessionMap = new ConcurrentHashMap<>();

  private final Logger log = LoggerFactory.getLogger(getClass());
  private static final Logger callingLog = LoggingUtil.callingLog;

  private final ControllerManager controllerManager = new ControllerManager();

  @SuppressWarnings("SpringJavaAutowiredMembersInspection")
  @Autowired
  private ApplicationContext applicationContext;

  /**
   * Installs this handler as the pre-execute interceptor of the controller manager and
   * registers every bean annotated with {@link WebsocketController}.
   */
  @Override
  public void afterPropertiesSet() {
    controllerManager.setPreExecuteInterceptorSupplier(() -> this);
    Collection<Object> controllers = applicationContext.getBeansWithAnnotation(WebsocketController.class).values();
    for (Object controller : controllers) {
      controllerManager.registerController(controller);
    }
  }

  /**
   * Logs the transport error; the session registry is not modified here.
   *
   * @param session   session on which the error happened
   * @param exception the transport error
   */
  @Override
  public void handleTransportError(@NonNull WebSocketSession session, @NonNull Throwable exception) {
    log.error("0hbnSRlbME :: TransportError in {}", session, exception);
  }

  /**
   * Registers the new session. If a session with the same id was already registered, the
   * previous one is closed. Afterwards {@link #createUserDataOnNewSession(String)} is called
   * and its non-null result is stored as the session's user data.
   *
   * @param webSocketSession the session that has just been opened
   */
  @Override
  public void afterConnectionEstablished(@NonNull WebSocketSession webSocketSession) {
    String sessionId = webSocketSession.getId();
    var sessionData = new SessionData<UserSessionData>(webSocketSession);
    if (callingLog.isInfoEnabled()) {
      callingLog.info(ConsoleColors.GREEN_BOLD() + "OPEN SESS " + ConsoleColors.RESET() + webSocketSession.getId());
    }
    var oldSessionData = sessionMap.put(sessionId, sessionData);
    if (oldSessionData != null) {
      closeSession(sessionId, oldSessionData);
    }

    // The session is already in the map at this point, so the hook may use the
    // session-id based accessors of this class.
    UserSessionData userData = createUserDataOnNewSession(sessionId);
    if (userData != null) {
      sessionData.userData.set(userData);
    }
  }

  /**
   * Hook for subclasses: creates the user data for a newly opened session.
   *
   * @param sessionId id of the new session
   * @return user data to attach to the session, or {@code null} (the default) to attach nothing
   */
  @SuppressWarnings({"unused", "RedundantSuppression"})
  protected UserSessionData createUserDataOnNewSession(String sessionId) {
    return null;
  }

  /**
   * Closes the WebSocket session and then the attached user data.
   *
   * <p>Both steps are always attempted: a failure while closing the socket must not leave
   * the user data unclosed. The first failure is rethrown (checked exceptions wrapped in
   * {@link RuntimeException}); a failure of the second step is attached to it as suppressed.
   *
   * @param sessionId   id of the session being closed
   * @param sessionData stored data of that session
   */
  private void closeSession(@NonNull String sessionId, @NonNull SessionData<UserSessionData> sessionData) {
    Exception failure = null;

    try {
      sessionData.webSocketSession.close();
    } catch (Exception e) {
      failure = e;
    }

    try {
      closeUserSessionData(sessionId, sessionData.userData.get());
    } catch (Exception e) {
      if (failure == null) {
        failure = e;
      } else {
        failure.addSuppressed(e);
      }
    }

    if (failure == null) {
      return;
    }
    if (failure instanceof RuntimeException) {
      throw (RuntimeException) failure;
    }
    throw new RuntimeException(failure);
  }

  /**
   * Closes the user data of a session. Subclasses may override to customise the clean-up.
   *
   * @param sessionId       id of the session being closed
   * @param userSessionData user data of the session; may be {@code null}, in which case nothing is done
   * @throws Exception if closing the user data fails
   */
  @SuppressWarnings({"RedundantThrows", "unused", "RedundantSuppression"})
  protected void closeUserSessionData(@NonNull String sessionId, UserSessionData userSessionData) throws Exception {
    if (userSessionData != null) {
      userSessionData.close();
    }
  }

  /**
   * Removes the closed session from the registry and releases its resources.
   *
   * @param session the session that has been closed
   * @param status  close status reported for the session
   */
  @Override
  public void afterConnectionClosed(@NonNull WebSocketSession session, @NonNull CloseStatus status) {
    if (callingLog.isInfoEnabled()) {
      callingLog.info(ConsoleColors.RED_BOLD() + "CLOS SESS " + ConsoleColors.RESET() + session.getId() +
              " with status: " + status);
    }
    removeSessionById(session.getId());
  }

  /**
   * Removes the session with the given id from the registry and closes it.
   * Does nothing when no such session is registered.
   *
   * @param sessionId id of the session to remove
   */
  protected void removeSessionById(String sessionId) {
    SessionData<UserSessionData> sessionData = sessionMap.remove(sessionId);
    if (sessionData != null) {
      closeSession(sessionId, sessionData);
    }
  }

  /**
   * Parses the payload as a {@link ToServer} JSON message and executes the requested service
   * through the controller manager.
   *
   * @param session session the message came from
   * @param message the incoming text message
   * @throws Exception if the payload cannot be parsed or the service execution fails
   */
  @Override
  protected void handleTextMessage(@NonNull WebSocketSession session, @NonNull TextMessage message) throws Exception {

    ToServer toServer = OBJECT_MAPPER.readValue(message.getPayload(), ToServer.class);

    ExecuteInput executeInput = ExecuteInput
      .builder()
      .sessionIdSupplier(session::getId)
      .setParams(toServer.params)
      .build();

    if (callingLog.isInfoEnabled()) {
      callingLog.info(ConsoleColors.BLUE_BOLD() + "TO_SERVER" + ConsoleColors.RESET()
        + " "
        + session.getId()
        + " "
        + ConsoleColors.BLUE() + toServer.service + ConsoleColors.RESET()
        + " : "
        + toServer.params
      );
    }

    controllerManager.findAndExecuteService(toServer.service, executeInput);
  }

  /**
   * Sends a message to the client of the given session.
   *
   * @param sessionId id of the target session
   * @param toClient  message to send
   * @throws RuntimeException if the session is not registered or sending fails
   */
  @SuppressWarnings({"unused", "RedundantSuppression"})
  public void sendToClient(@NonNull String sessionId, @NonNull ToClient toClient) {
    createMessageSender(sessionId).sendMessage(toClient);
  }

  /**
   * Looks up a registered session or fails with a coded error message.
   *
   * @param sessionId id of the session to find
   * @param errorCode code placed at the start of the error message to identify the call site
   * @return stored data of the session, never {@code null}
   * @throws RuntimeException if no session with this id is registered
   */
  private SessionData<UserSessionData> requireSession(String sessionId, String errorCode) {
    SessionData<UserSessionData> sessionData = sessionMap.get(sessionId);
    if (sessionData == null) {
      throw new RuntimeException(errorCode + " :: No session with id = " + sessionId);
    }
    return sessionData;
  }

  /**
   * Returns the stored data of a registered session.
   *
   * @param sessionId id of the session
   * @return stored data of the session
   * @throws RuntimeException if no session with this id is registered
   */
  private @NonNull SessionData<UserSessionData> sessionData(String sessionId) {
    return requireSession(sessionId, "77ioMM7DGw");
  }

  /**
   * Returns the token stored for the session.
   *
   * @param sessionId id of the session
   * @return the stored token, or {@code null} if none has been set
   * @throws RuntimeException if no session with this id is registered
   */
  @SuppressWarnings({"unused", "RedundantSuppression"})
  public String getToken(@NonNull String sessionId) {
    return sessionData(sessionId).token.get();
  }

  /**
   * Stores a token for the session.
   *
   * @param sessionId id of the session
   * @param value     token to store; may be {@code null}
   * @throws RuntimeException if no session with this id is registered
   */
  @SuppressWarnings({"unused", "RedundantSuppression"})
  public void setToken(@NonNull String sessionId, String value) {
    sessionData(sessionId).token.set(value);
  }

  /**
   * Returns the user data attached to the session.
   *
   * @param sessionId id of the session
   * @return the attached user data, or {@code null} if none is attached
   * @throws RuntimeException if no session with this id is registered
   */
  @SuppressWarnings({"unused", "RedundantSuppression"})
  protected UserSessionData getUserSessionData(@NonNull String sessionId) {
    return requireSession(sessionId, "KT6TIfC8j2").userData.get();
  }

  /**
   * Replaces the user data attached to the session. The previously attached value is not
   * closed by this method.
   *
   * @param sessionId       id of the session
   * @param userSessionData new user data; may be {@code null}
   * @throws RuntimeException if no session with this id is registered
   */
  @SuppressWarnings({"unused", "RedundantSuppression"})
  protected void setUserSessionData(@NonNull String sessionId, UserSessionData userSessionData) {
    requireSession(sessionId, "2p2euVyXsR").userData.set(userSessionData);
  }

  /**
   * Creates a sender bound to a session id. The session is looked up on every send, so the
   * sender fails with a {@link RuntimeException} once the session is no longer registered.
   *
   * @param sessionId id of the target session
   * @return sender that serializes messages to JSON and writes them to the session
   */
  public @NonNull MessageSender createMessageSender(@NonNull String sessionId) {
    return toClient -> {
      try {
        WebSocketSession webSocketSession = sessionData(sessionId).webSocketSession;

        // Serialize before taking the lock so the lock only covers the socket write.
        String json = OBJECT_MAPPER.writeValueAsString(toClient);

        // Writes to the same session are serialized on the session object itself.
        //noinspection SynchronizationOnLocalVariableOrMethodParameter
        synchronized (webSocketSession) {
          webSocketSession.sendMessage(new TextMessage(json));
        }

        if (callingLog.isInfoEnabled()) {
          callingLog.info(ConsoleColors.PURPLE_BOLD() + "TO_CLIENT" + ConsoleColors.RESET()
            + ' '
            + sessionId
            + ' '
            + ConsoleColors.PURPLE() + toClient.service + ConsoleColors.RESET()
            + " : "
            + toClient.body);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    };
  }

  /**
   * Default interceptor implementation: does nothing. Subclasses may override it.
   */
  @Override
  public void preExecute(Object controller, Method method, String serviceFullName,
                         ExecuteInput executeInput, AnnotationFinder annotationFinder) {}

}
