package rocks.imsofa.ai.puppychatter.openai;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import org.apache.commons.jxpath.JXPathContext;
import org.checkerframework.checker.units.qual.C;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import com.github.kevinsawicki.http.HttpRequest;
import com.github.victools.jsonschema.generator.OptionPreset;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.google.gson.Gson;
import com.hankcs.hanlp.dependency.nnparser.util.Log;

import rocks.imsofa.ai.puppychatter.BarkCallback;
import rocks.imsofa.ai.puppychatter.Conversation;
import rocks.imsofa.ai.puppychatter.Response;
import rocks.imsofa.ai.puppychatter.cache.CacheService;

/**
 * An {@link OpenAICompatiblePuppyChatter} that sends an OpenAI-compatible
 * chat-completion request and reads the provider's reply through a
 * {@link StreamResultIterator} wrapping the response {@link InputStream}.
 *
 * <p>Subclasses provide the provider-specific parts: building the request
 * ({@link #getRequestParameters}) and splitting the reply into lines
 * ({@link #getStreamResultIterator}). They may also rewrite the outgoing
 * messages in {@link #preprocessMessages}.
 *
 * @param <S> the prompt-parameter type accepted by this chatter
 * @param <T> the response type returned to callers and callbacks
 */
@SuppressWarnings("all")
public abstract class OpenAICompatibleInputStreamPuppyChatter<S extends OpenAICompatiblePromptParameters, T extends Response>
        extends OpenAICompatiblePuppyChatter<S, T> {

    /** Creates a chatter using the superclass defaults. */
    public OpenAICompatibleInputStreamPuppyChatter() {
        super();
    }

    /**
     * Creates a chatter without a cache service.
     *
     * @param replyRole role handed to the superclass for the model's replies
     */
    public OpenAICompatibleInputStreamPuppyChatter(String replyRole) {
        super(null, replyRole);
    }

    /**
     * @param cacheService cache service handed to the superclass (may be {@code null})
     * @param replyRole    role handed to the superclass for the model's replies
     */
    public OpenAICompatibleInputStreamPuppyChatter(CacheService cacheService, String replyRole) {
        super(cacheService, replyRole);
    }

    /**
     * Sends a non-streamed request and blocks until the whole reply has been parsed.
     */
    @Override
    protected T _bark(String sessionId, List<Conversation> messages, S parameters) throws Exception {
        RequestParameters requestParameters = getRequestParameters(sessionId, messages, parameters, false);
        preprocessMessages(requestParameters);
        constructPayloadAndSend(requestParameters);
        StreamResultIterator it = getStreamResultIterator(requestParameters);
        return processResponseStreamFromLLMProviders(sessionId, it, messages, parameters, null, false);
    }

    /**
     * Sends a request and reports the reply through {@code callback}.
     *
     * <p>Plain chat requests are streamed chunk by chunk. Requests with a JSON schema
     * or tools are sent non-streamed, because their reply can only be interpreted once
     * it is complete.
     */
    @Override
    protected void _bark(String sessionId, List<Conversation> messages, S parameters, BarkCallback<T> callback)
            throws Exception {
        boolean hasJsonSchema = parameters.getJsonSchema() != null;
        boolean hasTools = parameters.getTools() != null && parameters.getTools().size() > 0;
        boolean streamed = !hasJsonSchema && !hasTools;

        // The request must ask for the same mode the reply is parsed in; building it
        // with streamed=false while parsing delta chunks cannot work.
        RequestParameters requestParameters = getRequestParameters(sessionId, messages, parameters, streamed);
        preprocessMessages(requestParameters);
        constructPayloadAndSend(requestParameters);
        StreamResultIterator it = getStreamResultIterator(requestParameters);
        processResponseStreamFromLLMProviders(sessionId, it, messages, parameters, callback, streamed);
    }

    /**
     * Hook for subclasses to adjust the request (e.g. its effective messages)
     * before it is sent. The default implementation does nothing.
     *
     * @param requestParameters the request about to be sent
     * @throws Exception if preprocessing fails
     */
    protected void preprocessMessages(RequestParameters requestParameters) throws Exception {
        // default: no preprocessing
    }

    /**
     * Returns an iterator over the reply to a request previously sent by
     * {@link #constructPayloadAndSend}.
     *
     * @param requestParameters the request whose reply should be read
     * @return a line iterator over the reply
     * @throws Exception if the reply cannot be obtained
     */
    protected abstract StreamResultIterator getStreamResultIterator(RequestParameters requestParameters) throws Exception;

    /**
     * Builds the provider-specific request for a conversation.
     *
     * @param sessionId  the chat session
     * @param messages   the caller's conversation
     * @param parameters prompt parameters (schema, tools, ...)
     * @param streamed   whether the reply should be streamed; presumably copied into
     *                   {@link RequestParameters#isStream()} (NOTE(review): confirm in subclasses)
     * @return the prepared request
     * @throws Exception if the request cannot be built
     */
    protected abstract RequestParameters getRequestParameters(String sessionId, List<Conversation> messages,
            OpenAICompatiblePromptParameters parameters, boolean streamed) throws Exception;

    /**
     * Builds the chat-completion JSON body and writes it to the request's {@link HttpRequest}.
     *
     * <p>The body contains {@code model}, {@code messages} (the effective messages) and
     * {@code stream}. It also contains a strict {@code response_format} when a JSON schema
     * class is set, and a {@code tools} array when at least one tool is configured.
     *
     * @param requestParameters the request to send
     * @return the request's {@link HttpRequest}, after the body has been sent on it
     * @throws Exception if building or sending the payload fails
     */
    protected HttpRequest constructPayloadAndSend(RequestParameters requestParameters) throws Exception {
        Gson gson = new Gson();
        String model = requestParameters.getModel();
        LoggerFactory.getLogger(getClass()).info("using model: {}", model);

        OpenAICompatiblePromptParameters parameters = requestParameters.getParameters();
        Map input = new HashMap(Map.of(
                "model", model,
                "messages", requestParameters.getEffectiveMessages(),
                "stream", requestParameters.isStream()));

        // Structured output: a strict JSON schema generated from the schema class.
        if (parameters.getJsonSchema() != null) {
            Map schemaMap = getSchemaMap(parameters.getJsonSchema());
            schemaMap.put("additionalProperties", false);
            input.put("response_format", Map.of(
                    "type", "json_schema",
                    "json_schema", Map.of(
                            "strict", true,
                            // getJsonSchema() is already a Class; getClass() on it would always yield "Class"
                            "name", parameters.getJsonSchema().getSimpleName(),
                            "schema", schemaMap)));
        }

        // Only send "tools" when there is at least one; an empty array is rejected by the API.
        if (parameters.getTools() != null && parameters.getTools().size() > 0) {
            List<Map> tools = new ArrayList<>();
            for (Tool tool : parameters.getTools()) {
                Map functionMap = new HashMap();
                functionMap.put("name", tool.getName());
                functionMap.put("description", tool.getDescription());
                functionMap.put("parameters", getSchemaMap(tool.getParametersClass()));
                Map toolMap = new HashMap();
                toolMap.put("type", "function");
                toolMap.put("function", functionMap);
                tools.add(toolMap);
            }
            input.put("tools", tools);
        }

        HttpRequest request = requestParameters.getHttpRequest();
        request.send(gson.toJson(input));
        return request;
    }

    /**
     * Generates a JSON schema (draft 2020-12, plain JSON preset) for {@code schemaClass}.
     * A field is marked required only when annotated {@code @JsonProperty(required = true)}.
     *
     * @param schemaClass the class to describe
     * @return the schema as a mutable {@link Map}
     */
    private Map getSchemaMap(Class schemaClass) {
        SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12,
                OptionPreset.PLAIN_JSON);
        configBuilder.forFields().withRequiredCheck(field -> {
            JsonProperty annotation = field.getAnnotationConsideringFieldAndGetter(JsonProperty.class);
            return annotation != null && annotation.required();
        });
        SchemaGeneratorConfig config = configBuilder.build();
        JsonNode jsonSchema = new SchemaGenerator(config).generateSchema(schemaClass);
        // Round-trip through Gson to get a plain mutable Map that Gson can re-serialize later.
        return new HashMap(new Gson().fromJson(jsonSchema.toString(), Map.class));
    }

    /**
     * Consumes the provider's reply.
     *
     * <p>When {@code callback != null && streamed}, each line is parsed as a chunk and its
     * {@code choices[1]/delta/content} is forwarded immediately. A final {@code null}
     * chunk and the verification result follow. In every other case the lines are
     * joined, parsed as one completion, tool calls are run if tools are configured, and
     * the resulting response is returned (and handed to the callback, if any).
     *
     * @return the parsed response; {@code null} after a successful streamed-with-callback
     *         run; an error response ({@code isError}) if any step throws
     * @throws Exception declared for subclasses; exceptions raised here are caught and
     *         turned into an error response
     */
    protected T processResponseStreamFromLLMProviders(
            String sessionId,
            StreamResultIterator it,
            List<Conversation> messages,
            S parameters,
            BarkCallback<T> callback,
            boolean streamed) throws Exception {
        Response response = new Response();
        response.setLastPrompt(List.copyOf(messages));
        response.setMessage("");
        String ret = null;
        Gson gson = new Gson();
        StringBuilder sb = new StringBuilder();
        try {
            while (it.hasNext()) {
                String line = it.next();
                if (callback != null && streamed) {
                    // each line is one chunk; forward its delta right away
                    Map result = gson.fromJson(line, Map.class);
                    String output = (String) JXPathContext.newContext(result).getValue("choices[1]/delta/content");
                    Response chunk = new Response();
                    chunk.setMessage(output);
                    if (output != null) {
                        response.setMessage(response.getMessage() + output);
                    }
                    callback.responseChunkReceived((T) chunk, false);
                }
                // raw transcript: parsed as the body in non-streamed mode, verified in streamed mode
                sb.append(line).append("\r\n");
            }
            // NOTE(review): this replaces the accumulated delta text with the raw transcript;
            // confirm the response verifier expects raw lines in streamed mode.
            response.setMessage(sb.toString());

            if (callback != null && streamed) {
                Response lastChunk = new Response();
                lastChunk.setMessage(null);
                callback.responseChunkReceived((T) lastChunk, true);
                callback.finalVerificationResult(parameters.getResponseVerifier().verify(response));
                return null;
            }

            ret = sb.toString();
            LoggerFactory.getLogger(getClass()).debug("original returned message: {}", ret);
            Map result = gson.fromJson(ret, Map.class);
            JXPathContext context = JXPathContext.newContext(result);
            String output = (String) context.getValue("choices[1]/message/content");
            if (parameters.getTools() != null) {
                output = handleToolCalls(sessionId, context, messages, parameters, output);
            }

            response.setError(false);
            response.setErrorMessage(null);
            response.setMessage(output);
            if (callback != null && !streamed) {
                callback.responseChunkReceived((T) response, true);
                callback.finalVerificationResult(parameters.getResponseVerifier().verify(response));
            }
            return (T) response;
        } catch (Exception e) {
            LoggerFactory.getLogger(getClass()).error(e.getMessage() + "\r\n" + ret, e);
            response.setError(true);
            response.setErrorMessage(e.getMessage());
            return (T) response;
        }
    }

    /**
     * Runs the first not-yet-answered tool call of a non-streamed reply, if there is one.
     * It then appends the call and its result to {@code messages} and asks the model again.
     *
     * @return the follow-up reply's message when a tool ran, otherwise {@code output}
     * @throws IllegalStateException if a tool call must run but no tool call processor is set
     */
    private String handleToolCalls(String sessionId, JXPathContext context, List<Conversation> messages,
            S parameters, String output) throws Exception {
        Map<String, ToolCallRequest> knownRequests = this.toolCallRequests.get(sessionId);
        if (knownRequests == null) {
            knownRequests = new HashMap<>();
            toolCallRequests.put(sessionId, knownRequests);
        }
        List<Map> toolCalls = (List) context.getValue("choices[1]/message/tool_calls");
        if (toolCalls == null || toolCalls.isEmpty()) {
            return output;
        }

        // assistant turn that requested the tool call(s), replayed to the model below
        ToolCallRequestConversation toolCallConversation = new ToolCallRequestConversation();
        toolCallConversation.setRole("assistant");
        toolCallConversation.setContent("");
        List<ToolCallRequest> requestedCalls = new ArrayList<>();
        toolCallConversation.setTool_calls(requestedCalls);

        ToolCallRequest pending = null;
        for (int i = 0; i < toolCalls.size(); i++) {
            Map toolCall = toolCalls.get(i);
            // NOTE(review): newly built requests are never put into knownRequests here, so this
            // lookup only hits entries registered elsewhere — confirm that is intended.
            ToolCallRequest request = knownRequests.containsKey(toolCall.get("id"))
                    ? knownRequests.get(toolCall.get("id"))
                    : newToolCallRequest(i, toolCall);
            requestedCalls.add(request);
            if (request.getResponse() == null) {
                // only one unanswered call is handled per round trip
                pending = request;
                break;
            }
        }
        if (pending == null) {
            return output;
        }
        if (toolCallProcessor == null) {
            LoggerFactory.getLogger(getClass()).warn("tool call processor is not set");
            throw new IllegalStateException("tool call exists but tool call processor is not set");
        }

        String toolCallResponse = toolCallProcessor.processToolCallRequest(pending);
        pending.setResponse(toolCallResponse);
        ToolCallResponseConversation toolCallResponseConversation = new ToolCallResponseConversation(
                "tool",
                pending.getResponse(),
                pending.getId(),
                pending.getFunction().getName());
        // the caller's list is extended so the follow-up request carries the tool exchange
        messages.add(toolCallConversation);
        messages.add(toolCallResponseConversation);
        Response followUp = this._bark(sessionId, messages, parameters);
        return followUp.getMessage();
    }

    /**
     * Builds a {@link ToolCallRequest} from one raw entry of a reply's {@code tool_calls} array.
     * The function arguments are kept as the raw string the provider sent.
     */
    private static ToolCallRequest newToolCallRequest(int index, Map toolCall) {
        Map functionMap = (Map) toolCall.get("function");
        ToolCallRequest.Function function = new ToolCallRequest.Function();
        function.setName((String) functionMap.get("name"));
        function.setArguments((String) functionMap.get("arguments"));
        ToolCallRequest request = new ToolCallRequest();
        request.setIndex(index);
        request.setId((String) toolCall.get("id"));
        request.setFunction(function);
        return request;
    }

    /**
     * Cursor over a provider's reply, wrapping its {@link InputStream}. Each
     * {@link #next()} yields one line; in streamed mode every line is parsed as a JSON chunk.
     */
    public static abstract class StreamResultIterator {
        protected InputStream inputStream;

        public StreamResultIterator(InputStream inputStream) {
            this.inputStream = inputStream;
        }

        /** @return the next line of the reply */
        public abstract String next();

        /** @return whether another line is available */
        public abstract boolean hasNext();
    }

    /**
     * Carries one outgoing request. It bridges the prompts the caller supplied
     * ({@code originalMessages}) and the prompts actually sent ({@code effectiveMessages}),
     * since the original messages sometimes need transforming to satisfy a provider.
     * It also holds the session, model, prompt parameters, streaming flag and the
     * {@link HttpRequest} the payload is sent on.
     */
    public static class RequestParameters {
        private String sessionId;
        private String model;
        private List<Conversation> effectiveMessages;
        private List<Conversation> originalMessages;
        private OpenAICompatiblePromptParameters parameters;
        private boolean stream;
        private HttpRequest httpRequest;

        @Override
        public int hashCode() {
            return Objects.hash(sessionId, model, effectiveMessages, originalMessages, parameters, stream,
                    httpRequest);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            RequestParameters other = (RequestParameters) obj;
            return stream == other.stream
                    && Objects.equals(sessionId, other.sessionId)
                    && Objects.equals(model, other.model)
                    && Objects.equals(effectiveMessages, other.effectiveMessages)
                    && Objects.equals(originalMessages, other.originalMessages)
                    && Objects.equals(parameters, other.parameters)
                    && Objects.equals(httpRequest, other.httpRequest);
        }

        public String getModel() {
            return model;
        }

        public void setModel(String model) {
            this.model = model;
        }

        public List<Conversation> getEffectiveMessages() {
            return effectiveMessages;
        }

        public void setEffectiveMessages(List<Conversation> effectiveMessages) {
            this.effectiveMessages = effectiveMessages;
        }

        public List<Conversation> getOriginalMessages() {
            return originalMessages;
        }

        public void setOriginalMessages(List<Conversation> originalMessages) {
            this.originalMessages = originalMessages;
        }

        public OpenAICompatiblePromptParameters getParameters() {
            return parameters;
        }

        public void setParameters(OpenAICompatiblePromptParameters parameters) {
            this.parameters = parameters;
        }

        public boolean isStream() {
            return stream;
        }

        public void setStream(boolean stream) {
            this.stream = stream;
        }

        public HttpRequest getHttpRequest() {
            return httpRequest;
        }

        public void setHttpRequest(HttpRequest httpRequest) {
            this.httpRequest = httpRequest;
        }

        public String getSessionId() {
            return sessionId;
        }

        public void setSessionId(String sessionId) {
            this.sessionId = sessionId;
        }
    }
}
