package app.valuationcontrol.multimodule.library.helpers.openai;

import app.valuationcontrol.multimodule.library.entities.Model;
import app.valuationcontrol.multimodule.library.entities.Segment;
import app.valuationcontrol.multimodule.library.entities.Variable;
import app.valuationcontrol.multimodule.library.helpers.exceptions.ResourceException;
import app.valuationcontrol.multimodule.library.helpers.openai.OpenAiServiceImplementation;
import app.valuationcontrol.multimodule.library.records.CalculationData;
import app.valuationcontrol.multimodule.library.records.VariableData;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.FunctionExecutor;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import lombok.extern.log4j.Log4j2;
import org.apache.poi.ss.usermodel.DateUtil;
import org.springframework.http.HttpStatus;

/**
 * Static helpers used to build OpenAI prompt context from valuation model data, to format
 * calculated values as text, and to send chat completion requests.
 */
@Log4j2
public class OpenAIHelperFunctions {

  /**
   * Shared mapper, used only as a JSON node factory. {@link ObjectMapper} is thread-safe once
   * configured and costly to construct, so a single instance is reused instead of one per call.
   */
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  /**
   * Builds a JSON description of one variable of a model, intended as context for an OpenAI
   * prompt.
   *
   * <p>The returned object contains:
   *
   * <ul>
   *   <li>{@code model}: company, currency, name and the historical/projection period counts;
   *   <li>{@code variable}: the node for {@code variable} (if it is present in {@code
   *       calculationData});
   *   <li>{@code dependencies}: nodes for the variable's direct dependencies, only when {@code
   *       includeDependencies} is true and at least one of them is present in {@code
   *       calculationData}.
   * </ul>
   *
   * <p>Each variable node holds its name, id and formula. It also holds either a single {@code
   * value} (single-entry variables) or a {@code values} array with one {@code {year: value}} entry
   * per projection period. It holds {@code segmentValues} as well when {@code variable} is
   * modelled at segment level and the data has one extra entry per segment. Values that cannot be
   * converted to a number are left out.
   *
   * @param calculationData calculated values of the model's variables
   * @param model the model the variable belongs to
   * @param variable the variable to describe
   * @param includeDependencies whether the variable's direct dependencies are described too
   * @return the JSON description
   */
  public static ObjectNode prepareSingleVariable(
      CalculationData calculationData,
      Model model,
      Variable variable,
      boolean includeDependencies) {

    ObjectNode rootNode = OBJECT_MAPPER.createObjectNode();
    rootNode.set("model", buildModelNode(model));

    // Resolve the dependency ids once instead of re-streaming them for every calculated
    // variable, and skip the lookup entirely when dependencies are not requested.
    List<?> dependencyIds =
        includeDependencies
            ? variable.getVariableDependencies().stream().map(dep -> dep.getId()).toList()
            : List.of();

    ArrayNode variableDependenciesArray = OBJECT_MAPPER.createArrayNode();

    calculationData.getVariables().stream()
        .filter(v -> Objects.equals(v.id(), variable.getId()) || dependencyIds.contains(v.id()))
        .forEach(
            v -> {
              ObjectNode variableNode = OBJECT_MAPPER.createObjectNode();
              variableNode.put("variableName", v.variableName());
              variableNode.put("variableId", v.id());
              variableNode.put("formula", v.variableFormula());

              if (v.isSingleEntry()) {
                // Entry 0 is reported as the variable's value; entries 1..n are matched to
                // model.getSegments() in iteration order.
                Double constantValue =
                    getDoubleFromObject(v.singleOrConstantValue().get(0), v.variableFormat());
                if (Objects.nonNull(constantValue)) {
                  variableNode.put("value", constantValue);
                  if (variable.isModelledAtSegment()
                      && v.singleOrConstantValue().size() == (model.getSegments().size() + 1)) {
                    ArrayNode segmentValues = OBJECT_MAPPER.createArrayNode();
                    int segmentIndex = 1;
                    for (Segment seg : model.getSegments()) {
                      Double value =
                          getDoubleFromObject(
                              v.singleOrConstantValue().get(segmentIndex), v.variableFormat());
                      if (Objects.nonNull(value)) {
                        ObjectNode segmentVariableValue = OBJECT_MAPPER.createObjectNode();
                        segmentVariableValue.put("segmentName", seg.getSegmentName());
                        segmentVariableValue.put("value", value);
                        segmentValues.add(segmentVariableValue);
                      }
                      segmentIndex++;
                    }
                    variableNode.set("segmentValues", segmentValues);
                  }
                }
              } else {
                // Series 0 is reported under "values"; series 1..n are matched to
                // model.getSegments() in iteration order. Each period is keyed by
                // startYear + period index.
                ArrayNode valuesArray = OBJECT_MAPPER.createArrayNode();
                for (int period = 0; period < model.getNbProjectionPeriod(); period++) {
                  Double value =
                      getDoubleFromObject(v.projectionValues().get(0)[period], v.variableFormat());
                  if (Objects.nonNull(value)) {
                    ObjectNode valueNode = OBJECT_MAPPER.createObjectNode();
                    valueNode.put(String.valueOf(model.getStartYear() + period), value);
                    valuesArray.add(valueNode);
                  }
                }
                variableNode.set("values", valuesArray);

                if (variable.isModelledAtSegment()
                    && v.projectionValues().size() == (model.getSegments().size() + 1)) {
                  ArrayNode segmentArray = OBJECT_MAPPER.createArrayNode();
                  int segmentIndex = 1;
                  for (Segment seg : model.getSegments()) {
                    ObjectNode segmentNode = OBJECT_MAPPER.createObjectNode();
                    segmentNode.put("segmentName", seg.getSegmentName());
                    ArrayNode segmentValuesArray = OBJECT_MAPPER.createArrayNode();
                    for (int period = 0; period < model.getNbProjectionPeriod(); period++) {
                      Double value =
                          getDoubleFromObject(
                              v.projectionValues().get(segmentIndex)[period], v.variableFormat());
                      if (Objects.nonNull(value)) {
                        ObjectNode valueNode = OBJECT_MAPPER.createObjectNode();
                        valueNode.put(String.valueOf(model.getStartYear() + period), value);
                        segmentValuesArray.add(valueNode);
                      }
                    }
                    segmentNode.set("values", segmentValuesArray);
                    segmentArray.add(segmentNode);
                    segmentIndex++;
                  }
                  variableNode.set("segmentValues", segmentArray);
                }
              }

              if (Objects.equals(v.id(), variable.getId())) {
                rootNode.set("variable", variableNode);
              } else {
                variableDependenciesArray.add(variableNode);
              }
            });

    // Only emit the "dependencies" key when at least one dependency was found.
    if (!variableDependenciesArray.isEmpty()) {
      rootNode.set("dependencies", variableDependenciesArray);
    }

    return rootNode;
  }

  /** Builds the {@code model} node: company, currency, name and the period counts. */
  private static ObjectNode buildModelNode(Model model) {
    ObjectNode modelNode = OBJECT_MAPPER.createObjectNode();
    modelNode.put("company", model.getCompany());
    modelNode.put("currency", model.getCurrency());
    modelNode.put("name", model.getName());
    modelNode.put("nbHistoricalPeriod", model.getNbHistoricalPeriod());
    modelNode.put("nbProjectionPeriod", model.getNbProjectionPeriod());
    return modelNode;
  }

  /**
   * Converts a numeric value to a {@link Double} rounded half-up.
   *
   * <p>Values of a {@code "percent"} format (case-insensitive) keep 4 decimals; every other
   * format, including {@code null}, keeps 1 decimal.
   *
   * @param valueObject the raw value; any {@link Number} is accepted
   * @param variableFormat the variable's format name; may be {@code null}
   * @return the rounded value, or {@code null} when {@code valueObject} is {@code null}, not a
   *     {@link Number}, NaN or infinite
   */
  public static Double getDoubleFromObject(Object valueObject, String variableFormat) {
    final BigDecimal decimal;
    if (valueObject instanceof BigDecimal bigDecimal) {
      decimal = bigDecimal;
    } else if (valueObject instanceof Number number) {
      double raw = number.doubleValue();
      // BigDecimal cannot represent NaN or infinity; valueOf would throw NumberFormatException.
      if (!Double.isFinite(raw)) {
        return null;
      }
      decimal = BigDecimal.valueOf(raw);
    } else {
      // null or a non-numeric value: nothing to convert.
      return null;
    }

    int scale = "percent".equalsIgnoreCase(variableFormat) ? 4 : 1;
    return decimal.setScale(scale, RoundingMode.HALF_UP).doubleValue();
  }

  /**
   * Formats a value for display according to the variable's format.
   *
   * <ul>
   *   <li>{@code "date"}: the number is converted with POI's {@code DateUtil.getJavaDate} and
   *       printed as {@code yyyy-MM-dd};
   *   <li>{@code "percent"}: the value is multiplied by 100 and suffixed with {@code %};
   *   <li>negative numbers are shown in parentheses as their absolute value, e.g. {@code (1.5)};
   *   <li>numbers are printed with one decimal, rounded half-up.
   * </ul>
   *
   * @param valueObject the raw value
   * @param variableData supplies the variable's format
   * @return the formatted text, or a single space when the value is not numeric
   */
  public static String getFormattedValueAsString(Object valueObject, VariableData variableData) {
    String format = variableData.variableFormat();
    Double value = getDoubleFromObject(valueObject, format);
    if (value == null) {
      return " ";
    }

    if ("date".equalsIgnoreCase(format)) {
      // SimpleDateFormat is not thread-safe, so a fresh instance is used per call.
      SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd");
      Date date = DateUtil.getJavaDate(value);
      return simpleDateFormat.format(date);
    }

    String prefix = "";
    String suffix = "";
    if ("percent".equalsIgnoreCase(format)) {
      value = value * 100D;
      suffix = "%";
    }

    // Accounting style: negative numbers are shown in parentheses.
    if (value < 0) {
      value = Math.abs(value);
      prefix = "(";
      suffix = ")" + suffix;
    }

    String number = BigDecimal.valueOf(value).setScale(1, RoundingMode.HALF_UP).toString();
    return prefix + number + suffix;
  }

  /**
   * Sends the conversation to OpenAI and returns the assistant's reply as JSON.
   *
   * <p>The reply message is appended to {@code messages}, so the list must be mutable. When the
   * reply is a function call its arguments are returned as-is. Otherwise the text content is
   * wrapped as {@code {"response": "<content>"}}.
   *
   * @param openAiServiceImplementation builds the request and provides the OpenAI client
   * @param messages the conversation so far; the reply is appended to it
   * @param functionExecutor functions offered to the model, or {@code null} for a plain chat
   *     request
   * @param functionName passed to the request builder when {@code functionExecutor} is non-null
   * @return the function-call arguments, or an object with a {@code response} field
   * @throws ResourceException with {@code GATEWAY_TIMEOUT} if building the request, calling
   *     OpenAI or reading its first choice fails. Also thrown with {@code BAD_REQUEST} if the
   *     reply has no message, no usable function-call arguments or no text content.
   */
  public static JsonNode doRequest(
      OpenAiServiceImplementation openAiServiceImplementation,
      List<ChatMessage> messages,
      FunctionExecutor functionExecutor,
      String functionName) {
    ChatMessage responseMessage;
    // Only the exchange with OpenAI is mapped to GATEWAY_TIMEOUT. The BAD_REQUEST validation
    // errors below are thrown outside this try so they are not turned into a timeout.
    try {
      ChatCompletionRequest chatCompletionRequest =
          functionExecutor != null
              ? openAiServiceImplementation.getChatRequestWithFunction(
                  messages, functionExecutor, functionName)
              : openAiServiceImplementation.getChatRequest(messages);
      log.debug("Chat was created");
      ChatCompletionResult chatCompletionResult =
          openAiServiceImplementation
              .getOpenAiService()
              .createChatCompletion(chatCompletionRequest);
      responseMessage = chatCompletionResult.getChoices().get(0).getMessage();
    } catch (Exception e) {
      log.debug("OpenAI chat completion failed", e);
      throw new ResourceException(HttpStatus.GATEWAY_TIMEOUT, "the OPen AI server did not respond");
    }

    if (responseMessage == null) {
      throw new ResourceException(HttpStatus.BAD_REQUEST, "Couldn't parse empty response OpenAI");
    }
    // Keep the conversation up to date with the latest reply.
    messages.add(responseMessage);

    ChatFunctionCall functionCall = responseMessage.getFunctionCall();
    if (functionCall != null) {
      JsonNode arguments = functionCall.getArguments();
      if (arguments == null || arguments.toPrettyString().isEmpty()) {
        throw new ResourceException(
            HttpStatus.BAD_REQUEST, "Couldn't parse function call to OpenAI");
      }
      log.debug("Executed {}.", functionCall.getName());
      return arguments;
    }

    String content = responseMessage.getContent();
    if (content == null || content.isEmpty()) {
      throw new ResourceException(HttpStatus.BAD_REQUEST, "Couldn't parse empty response OpenAI");
    }
    ObjectNode rootNode = OBJECT_MAPPER.createObjectNode();
    rootNode.put("response", content);
    return rootNode;
  }
}
