/*
 * Decompiled with CFR 0.152.
 */
package cool.cena.openai.context;

import cool.cena.openai.OpenAiApiAccessor;
import cool.cena.openai.autoconfigure.OpenAiProperties;
import cool.cena.openai.exception.chatcompletion.ChatCompletionOutDatedException;
import cool.cena.openai.pojo.chatcompletion.ChatCompletionMessage;
import cool.cena.openai.pojo.chatcompletion.OpenAiChatCompletionRequestBody;
import cool.cena.openai.pojo.chatcompletion.OpenAiChatCompletionResponseBody;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A chat-completion conversation whose messages are kept in a tree, so that alternative branches
 * (several choices returned for one request, or a conversation rewound with
 * {@link #switchVersion(int)}) can coexist.
 *
 * <p>The current position in the tree is a {@link Version}: the path of child indices from the
 * top level down to the most recent message. Request parameters live on a single
 * {@link OpenAiChatCompletionRequestBody} that is mutated by the fluent setters and by
 * {@link #create()}.
 *
 * <p>NOTE(review): no synchronization is performed; an instance appears intended to be used by one
 * thread at a time — confirm.
 */
public class OpenAiChatCompletionContext {
    private final OpenAiApiAccessor apiAccessor;
    private final OpenAiChatCompletionRequestBody requestBody;
    private final MessageSearchTree messageSearchTree;
    private Version version;
    /** Budget for the summed recorded token counts of the prompt history; {@code null} = no limit. */
    private Integer maxPromptToken;

    /**
     * Creates an empty context.
     *
     * @param apiAccessor accessor used by {@link #create()} to send the request
     * @param openAiChatCompletionProperties initial request parameters and prompt-token budget
     */
    public OpenAiChatCompletionContext(OpenAiApiAccessor apiAccessor, OpenAiProperties.OpenAiChatCompletionProperties openAiChatCompletionProperties) {
        this.apiAccessor = apiAccessor;
        this.requestBody = new OpenAiChatCompletionRequestBody(openAiChatCompletionProperties);
        this.maxPromptToken = openAiChatCompletionProperties.getMaxPromptToken();
        this.messageSearchTree = new MessageSearchTree();
        this.version = new Version(new ArrayList<Integer>());
    }

    /** Sets the {@code model} field of the request body. @return this context */
    public OpenAiChatCompletionContext setModel(String model) {
        this.requestBody.setModel(model);
        return this;
    }

    /** Sets the {@code user} field of the request body. @return this context */
    public OpenAiChatCompletionContext setUser(String user) {
        this.requestBody.setUser(user);
        return this;
    }

    /** Sets the {@code temperature} field of the request body. @return this context */
    public OpenAiChatCompletionContext setTemperature(Double temperature) {
        this.requestBody.setTemperature(temperature);
        return this;
    }

    /** Sets the {@code topP} field of the request body. @return this context */
    public OpenAiChatCompletionContext setTopP(Double topP) {
        this.requestBody.setTopP(topP);
        return this;
    }

    /**
     * Sets the budget used to truncate history in {@link #create()}: the oldest messages are
     * dropped until the summed recorded token counts fit. {@code null} disables truncation.
     *
     * @return this context
     */
    public OpenAiChatCompletionContext setMaxPromptToken(Integer maxPromptToken) {
        this.maxPromptToken = maxPromptToken;
        return this;
    }

    /** Sets the max-completion-token field of the request body. @return this context */
    public OpenAiChatCompletionContext setMaxCompletionToken(Integer maxCompletionToken) {
        this.requestBody.setMaxCompletionToken(maxCompletionToken);
        return this;
    }

    /** Sets the {@code n} (number of choices) field of the request body. @return this context */
    public OpenAiChatCompletionContext setN(Integer n) {
        this.requestBody.setN(n);
        return this;
    }

    /** Sets the {@code presencePenalty} field of the request body. @return this context */
    public OpenAiChatCompletionContext setPresencePenalty(Double presencePenalty) {
        this.requestBody.setPresencePenalty(presencePenalty);
        return this;
    }

    /** Sets the {@code frequencyPenalty} field of the request body. @return this context */
    public OpenAiChatCompletionContext setFrequencyPenalty(Double frequencyPenalty) {
        this.requestBody.setFrequencyPenalty(frequencyPenalty);
        return this;
    }

    /** Sets the {@code logitBias} field of the request body. @return this context */
    public OpenAiChatCompletionContext setLogitBias(Map<Integer, Double> logitBias) {
        this.requestBody.setLogitBias(logitBias);
        return this;
    }

    /** Sets the {@code stop} field of the request body. @return this context */
    public OpenAiChatCompletionContext setStop(List<String> stop) {
        this.requestBody.setStop(stop);
        return this;
    }

    /** Appends {@code newMessage} (token count 0) below the current version and makes it current. */
    private void addMessage(ChatCompletionMessage newMessage) {
        this.version = this.messageSearchTree.insert(this.version, newMessage);
    }

    /** Appends a {@code "system"} message below the current version. @return this context */
    public OpenAiChatCompletionContext addSystemMessage(String newMessageContent) {
        this.addMessage(new ChatCompletionMessage("system", newMessageContent));
        return this;
    }

    /** Appends a {@code "user"} message below the current version. @return this context */
    public OpenAiChatCompletionContext addUserMessage(String newMessageContent) {
        this.addMessage(new ChatCompletionMessage("user", newMessageContent));
        return this;
    }

    /** Appends an {@code "assistant"} message below the current version. @return this context */
    public OpenAiChatCompletionContext addAssistantMessage(String newMessageContent) {
        this.addMessage(new ChatCompletionMessage("assistant", newMessageContent));
        return this;
    }

    /** Returns the current position in the message tree. */
    public Version getVersion() {
        return this.version;
    }

    /**
     * Moves to sibling {@code n} of the current message, i.e. replaces the last path element with
     * {@code n}, and discards that sibling's descendants.
     *
     * @return this context
     * @throws IndexOutOfBoundsException if no such sibling exists; the current version is unchanged
     */
    public OpenAiChatCompletionContext switchVersion(int n) {
        return this.switchVersion(this.version.translate(n));
    }

    /**
     * Moves to {@code version} and discards all messages below it, so the next message continues
     * the conversation from there. Switching to the empty (initial) version clears the whole tree.
     *
     * @return this context
     * @throws IndexOutOfBoundsException if {@code version} does not exist; the current version is
     *     unchanged
     */
    public OpenAiChatCompletionContext switchVersion(Version version) {
        // Refresh first: it validates the path, and the context must never point at a missing node.
        this.messageSearchTree.refresh(version);
        this.version = version;
        return this;
    }

    /**
     * Sends the messages on the current path and records the reply.
     *
     * <p>The oldest messages are dropped until the summed recorded token counts fit
     * {@code maxPromptToken}. After the response arrives, the reported prompt tokens minus the
     * tokens already accounted for are stored on the latest message. The returned choice(s) are
     * appended as children of the request version. With exactly one choice, its completion token
     * count is recorded and it becomes current. With several choices, the first becomes current and
     * the rest are stored as its siblings with token count 0.
     *
     * @return the raw response
     * @throws ChatCompletionOutDatedException if, by the time the response arrives, the request
     *     version has acquired children or no longer exists
     */
    public OpenAiChatCompletionResponseBody create() {
        Version requestVersion = new Version(this.version);
        int promptBudget = this.maxPromptToken == null ? Integer.MAX_VALUE : this.maxPromptToken;
        PromptMessage prompt = this.messageSearchTree.getPromptMessage(requestVersion, promptBudget);
        this.requestBody.setMessages(prompt.promptMessages);
        OpenAiChatCompletionResponseBody response = this.apiAccessor.sendRequest(this.requestBody);
        if (!this.messageSearchTree.checkLatestVersion(requestVersion)) {
            throw new ChatCompletionOutDatedException();
        }
        int responsePromptToken = response.getPromptToken();
        this.messageSearchTree.setToken(requestVersion, responsePromptToken - prompt.promptToken);
        List<OpenAiChatCompletionResponseBody.OpenAiChatCompletionResponseChoice> responseChoices = response.getChoices();
        // All choices are attached under requestVersion (not this.version) so they end up on the
        // branch that was actually sent.
        if (responseChoices.size() == 1) {
            int responseCompletionToken = response.getCompletionToken();
            ChatCompletionMessage responseMessage = response.getObjectMessage();
            this.version = this.messageSearchTree.insert(requestVersion, responseMessage, responseCompletionToken);
        } else {
            this.version = this.messageSearchTree.insert(requestVersion, response.getObjectMessage());
            for (int i = 1; i < responseChoices.size(); ++i) {
                this.messageSearchTree.insert(requestVersion, response.getObjectMessage(i));
            }
        }
        return response;
    }

    /** Appends a {@code "user"} message and then calls {@link #create()}. */
    public OpenAiChatCompletionResponseBody create(String newMessageContent) {
        return this.addUserMessage(newMessageContent).create();
    }

    /** Tree of messages addressed by {@link Version} paths of child indices. */
    private static class MessageSearchTree {
        /** Top-level messages; every node owns its own child list. */
        private List<MessageNode> root = new ArrayList<>();

        /** Appends {@code message} with token count 0 below {@code version}. */
        private Version insert(Version version, ChatCompletionMessage message) {
            return this.insert(version, message, 0);
        }

        /**
         * Appends {@code message} as the last child of the node at {@code version}.
         *
         * @return the version addressing the new node
         */
        private Version insert(Version version, ChatCompletionMessage message, int token) {
            List<MessageNode> siblings = this.childrenAt(version);
            List<Integer> newPath = new ArrayList<>(version.path);
            newPath.add(siblings.size());
            siblings.add(new MessageNode(message, token));
            return new Version(newPath);
        }

        /**
         * Collects the messages along {@code version}'s path, dropping the oldest ones whenever the
         * running token sum exceeds {@code maxToken}.
         */
        private PromptMessage getPromptMessage(Version version, int maxToken) {
            List<MessageNode> pathNodes = new ArrayList<>();
            List<MessageNode> children = this.root;
            int first = 0;           // index in pathNodes of the oldest message still kept
            int cumulativeToken = 0; // token sum of pathNodes[first..]
            for (Integer index : version.path) {
                MessageNode node = children.get(index);
                pathNodes.add(node);
                cumulativeToken += node.token;
                // Stop once nothing is left so a negative budget cannot run past the end.
                while (cumulativeToken > maxToken && first < pathNodes.size()) {
                    cumulativeToken -= pathNodes.get(first).token;
                    ++first;
                }
                children = node.childs;
            }
            List<ChatCompletionMessage> promptMessages = new ArrayList<>();
            for (MessageNode node : pathNodes.subList(first, pathNodes.size())) {
                promptMessages.add(node.message);
            }
            return new PromptMessage(promptMessages, cumulativeToken);
        }

        /** True if {@code version} still exists and has no children (it is a leaf). */
        private boolean checkLatestVersion(Version version) {
            return this.exists(version) && this.childrenAt(version).isEmpty();
        }

        /** Records {@code token} on the node at {@code version}; no-op for the empty version. */
        private void setToken(Version version, int token) {
            MessageNode node = this.nodeAt(version);
            if (node != null) {
                node.token = token;
            }
        }

        /** Discards everything below {@code version} (the whole tree for the empty version). */
        private void refresh(Version version) {
            MessageNode node = this.nodeAt(version);
            if (node == null) {
                this.root = new ArrayList<>();
            } else {
                node.childs = new ArrayList<>();
            }
        }

        /**
         * Returns the node at {@code version}, or {@code null} for the empty version.
         *
         * @throws IndexOutOfBoundsException if the path does not exist
         */
        private MessageNode nodeAt(Version version) {
            MessageNode node = null;
            List<MessageNode> children = this.root;
            for (Integer index : version.path) {
                node = children.get(index);
                children = node.childs;
            }
            return node;
        }

        /** Returns the live child list of the node at {@code version} (the root list if empty). */
        private List<MessageNode> childrenAt(Version version) {
            MessageNode node = this.nodeAt(version);
            return node == null ? this.root : node.childs;
        }

        /** True if every index along {@code version}'s path is in range. */
        private boolean exists(Version version) {
            List<MessageNode> children = this.root;
            for (Integer index : version.path) {
                if (index < 0 || index >= children.size()) {
                    return false;
                }
                children = children.get(index).childs;
            }
            return true;
        }

        /** One message, its recorded token count, and its children. */
        private static class MessageNode {
            private List<MessageNode> childs = new ArrayList<>();
            private final ChatCompletionMessage message;
            private int token;

            private MessageNode(ChatCompletionMessage message, int token) {
                this.message = message;
                this.token = token;
            }
        }
    }

    /** A position in the message tree: the path of child indices from the top level. */
    public static class Version {
        private final List<Integer> path;

        /** Wraps {@code path} without copying; callers pass a freshly built list. */
        private Version(List<Integer> path) {
            this.path = path;
        }

        /** Copy constructor; the path list is copied. */
        private Version(Version version) {
            this.path = new ArrayList<>(version.path);
        }

        /** Returns the version whose last path element is replaced by {@code n} ({@code [n]} if empty). */
        private Version translate(int n) {
            int prefixLength = Math.max(0, this.path.size() - 1);
            List<Integer> newPath = new ArrayList<>(this.path.subList(0, prefixLength));
            newPath.add(n);
            return new Version(newPath);
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("context version: ");
            for (Integer location : this.path) {
                text.append(location).append("-");
            }
            return text.toString();
        }
    }

    /** The messages selected for a request together with their summed recorded token counts. */
    private static class PromptMessage {
        private final List<ChatCompletionMessage> promptMessages;
        private final int promptToken;

        private PromptMessage(List<ChatCompletionMessage> promptMessages, int promptToken) {
            this.promptMessages = promptMessages;
            this.promptToken = promptToken;
        }
    }
}

