/*
 * Decompiled with CFR 0.152.
 */
package rocks.imsofa.ai.puppychatter.openai;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.kevinsawicki.http.HttpRequest;
import com.github.victools.jsonschema.generator.OptionPreset;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.google.gson.Gson;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.jxpath.JXPathContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rocks.imsofa.ai.puppychatter.BarkCallback;
import rocks.imsofa.ai.puppychatter.Conversation;
import rocks.imsofa.ai.puppychatter.PromptParameters;
import rocks.imsofa.ai.puppychatter.Response;
import rocks.imsofa.ai.puppychatter.cache.CacheService;
import rocks.imsofa.ai.puppychatter.openai.OpenAICompatiblePromptParameters;
import rocks.imsofa.ai.puppychatter.openai.OpenAICompatiblePuppyChatter;
import rocks.imsofa.ai.puppychatter.openai.Tool;
import rocks.imsofa.ai.puppychatter.openai.ToolCallRequest;
import rocks.imsofa.ai.puppychatter.openai.ToolCallRequestConversation;
import rocks.imsofa.ai.puppychatter.openai.ToolCallResponseConversation;

/**
 * Base class for OpenAI-compatible chatters that read the provider's HTTP reply line by line
 * through a {@link StreamResultIterator}.
 *
 * <p>Subclasses supply the prepared request via {@link #getRequestParameters} and the line reader
 * via {@link #getStreamResultIterator}. This class serialises the chat payload (model, messages,
 * stream flag, optional JSON-schema response format and tool definitions), sends it, converts the
 * reply into a {@link Response}, and, when the model asks for a tool call, runs it and re-asks the
 * model with the tool result.
 *
 * @param <S> prompt-parameter type
 * @param <T> response type handed back to callers
 */
public abstract class OpenAICompatibleInputStreamPuppyChatter<S extends OpenAICompatiblePromptParameters, T extends Response>
extends OpenAICompatiblePuppyChatter<S, T> {
    /** Shared JSON codec; Gson instances are thread-safe and can be reused. */
    private static final Gson GSON = new Gson();

    public OpenAICompatibleInputStreamPuppyChatter() {
    }

    /**
     * @param replyRole forwarded to the superclass together with a {@code null} cache service
     */
    public OpenAICompatibleInputStreamPuppyChatter(String replyRole) {
        super(null, replyRole);
    }

    /**
     * @param cacheService forwarded to the superclass
     * @param replyRole forwarded to the superclass
     */
    public OpenAICompatibleInputStreamPuppyChatter(CacheService cacheService, String replyRole) {
        super(cacheService, replyRole);
    }

    /**
     * Sends one chat request and returns the fully parsed (non-streamed) reply.
     */
    @Override
    protected T _bark(String sessionId, List<Conversation> messages, S parameters) throws Exception {
        RequestParameters requestParameters = this.getRequestParameters(sessionId, messages, parameters, false);
        this.preprocessMessages(requestParameters);
        this.constructPayloadAndSend(requestParameters);
        StreamResultIterator it = this.getStreamResultIterator(requestParameters);
        return this.processResponseStreamFromLLMProviders(sessionId, it, messages, parameters, null, false);
    }

    /**
     * Sends one chat request and reports the reply through {@code callback}.
     *
     * <p>Replies are delivered chunk by chunk only when neither a JSON schema nor any tool is
     * configured; otherwise the whole reply is parsed first and delivered as one final chunk.
     */
    @Override
    protected void _bark(String sessionId, List<Conversation> messages, S parameters, BarkCallback<T> callback) throws Exception {
        RequestParameters requestParameters = this.getRequestParameters(sessionId, messages, parameters, false);
        this.preprocessMessages(requestParameters);
        // Structured output and tool calls are parsed from the complete JSON reply,
        // so chunked delivery is reserved for plain text replies.
        boolean needsWholeReply = parameters.getJsonSchema() != null
                || (parameters.getTools() != null && parameters.getTools().size() > 0);
        this.constructPayloadAndSend(requestParameters);
        StreamResultIterator it = this.getStreamResultIterator(requestParameters);
        this.processResponseStreamFromLLMProviders(sessionId, it, messages, parameters, callback, !needsWholeReply);
    }

    /**
     * Hook for subclasses to adjust the request (e.g. its effective messages) before the payload
     * is built. The default implementation does nothing.
     */
    protected void preprocessMessages(RequestParameters requestParameters) throws Exception {
    }

    /**
     * Returns a line reader over the reply of the request previously sent on
     * {@code requestParameters}.
     */
    protected abstract StreamResultIterator getStreamResultIterator(RequestParameters requestParameters) throws Exception;

    /**
     * Prepares the request (model, effective messages, HTTP request) for one call.
     * NOTE(review): the meaning of {@code var4} is not visible here (callers in this class always
     * pass {@code false}); presumably the stream flag — TODO confirm.
     */
    protected abstract RequestParameters getRequestParameters(String sessionId, List<Conversation> messages, OpenAICompatiblePromptParameters parameters, boolean var4) throws Exception;

    /**
     * Builds the JSON chat request body and sends it on the request held by
     * {@code requestParameters}.
     *
     * <p>The body always contains {@code model}, {@code messages} (the effective messages) and
     * {@code stream}. A strict {@code json_schema} {@code response_format} is added when a schema
     * class is configured, and a {@code tools} array is added when at least one tool is configured.
     *
     * @param requestParameters prepared request; its HTTP request must not be {@code null}
     * @return the same {@link HttpRequest} after the body has been sent
     * @throws Exception if schema generation or sending fails
     */
    protected HttpRequest constructPayloadAndSend(RequestParameters requestParameters) throws Exception {
        String model = requestParameters.getModel();
        this.classLogger().info("using model: {}", model);
        // Map.of rejects nulls, so a missing model or message list fails fast here.
        Map<String, Object> input = new HashMap<>(Map.<String, Object>of(
                "model", model,
                "messages", requestParameters.getEffectiveMessages(),
                "stream", requestParameters.isStream()));
        OpenAICompatiblePromptParameters parameters = requestParameters.getParameters();
        Class<?> jsonSchema = parameters.getJsonSchema();
        if (jsonSchema != null) {
            Map<String, Object> schemaMap = this.getSchemaMap(jsonSchema);
            schemaMap.put("additionalProperties", false);
            // Name the schema after the configured class itself (jsonSchema.getClass() is always Class).
            input.put("response_format", Map.of(
                    "type", "json_schema",
                    "json_schema", Map.of(
                            "strict", true,
                            "name", jsonSchema.getSimpleName(),
                            "schema", schemaMap)));
        }
        // An empty "tools" array is omitted rather than sent.
        if (parameters.getTools() != null && parameters.getTools().size() > 0) {
            List<Map<String, Object>> tools = new ArrayList<>();
            for (Tool tool : parameters.getTools()) {
                Map<String, Object> functionMap = new HashMap<>();
                functionMap.put("name", tool.getName());
                functionMap.put("description", tool.getDescription());
                functionMap.put("parameters", this.getSchemaMap(tool.getParametersClass()));
                Map<String, Object> toolMap = new HashMap<>();
                toolMap.put("type", "function");
                toolMap.put("function", functionMap);
                tools.add(toolMap);
            }
            input.put("tools", tools);
        }
        HttpRequest request = requestParameters.getHttpRequest();
        request.send((CharSequence) GSON.toJson(input));
        return request;
    }

    /**
     * Generates a Draft 2020-12 JSON schema for {@code schemaClass} and returns it as a mutable map.
     * A property is marked required only when annotated with {@code @JsonProperty(required = true)}.
     */
    @SuppressWarnings("unchecked")
    private Map<String, Object> getSchemaMap(Class<?> schemaClass) {
        SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON);
        configBuilder.forFields().withRequiredCheck(field -> {
            JsonProperty annotation = (JsonProperty) field.getAnnotationConsideringFieldAndGetter(JsonProperty.class);
            return annotation != null && annotation.required();
        });
        SchemaGeneratorConfig config = configBuilder.build();
        ObjectNode jsonSchema = new SchemaGenerator(config).generateSchema((Type) schemaClass, new Type[0]);
        // Copy into a HashMap so callers can add keys such as "additionalProperties".
        return new HashMap<String, Object>(GSON.fromJson(jsonSchema.toString(), Map.class));
    }

    /**
     * Consumes the provider's reply and converts it into a {@link Response}.
     *
     * <p>When {@code callback != null && streamed}, every line is parsed as a JSON chunk and its
     * {@code choices[1]/delta/content} (JXPath indices are 1-based) is forwarded to the callback.
     * A final chunk with a {@code null} message marks the end, and the verifier result for the
     * concatenated delta text is reported. The method then returns {@code null}.
     *
     * <p>Otherwise all lines are joined with CRLF and parsed as one JSON document whose
     * {@code choices[1]/message/content} becomes the reply. When tools are configured, a pending
     * tool call in the reply is executed and the conversation is continued (see
     * {@link #runPendingToolCall}). With a callback, the finished response is delivered as the
     * final chunk followed by its verification result.
     *
     * <p>Any exception is logged and turned into a response with {@code error=true}.
     */
    @SuppressWarnings("unchecked")
    protected T processResponseStreamFromLLMProviders(String sessionId, StreamResultIterator it, List<Conversation> messages, S parameters, BarkCallback<T> callback, boolean streamed) throws Exception {
        Response response = new Response();
        response.setLastPrompt(List.copyOf(messages));
        response.setMessage("");
        boolean deliverChunks = callback != null && streamed;
        String ret = null;
        StringBuilder rawBody = new StringBuilder();
        StringBuilder streamedText = new StringBuilder();
        try {
            while (it.hasNext()) {
                String line = it.next();
                if (deliverChunks) {
                    Map chunkJson = (Map) GSON.fromJson(line, Map.class);
                    String output = (String) JXPathContext.newContext((Object) chunkJson).getValue("choices[1]/delta/content");
                    Response chunk = new Response();
                    chunk.setMessage(output);
                    // A delta without text must not contribute the literal "null".
                    if (output != null) {
                        streamedText.append(output);
                    }
                    callback.responseChunkReceived(chunk, false);
                } else {
                    rawBody.append(line).append("\r\n");
                }
            }
            if (deliverChunks) {
                // Verify the text the caller actually received, not the raw provider lines.
                response.setMessage(streamedText.toString());
                Response endMarker = new Response();
                endMarker.setMessage(null);
                callback.responseChunkReceived(endMarker, true);
                callback.finalVerificationResult(((PromptParameters) parameters).getResponseVerifier().verify(response));
                return null;
            }
            ret = rawBody.toString();
            this.classLogger().debug("original returned message: {}", ret);
            Map result = (Map) GSON.fromJson(ret, Map.class);
            JXPathContext context = JXPathContext.newContext((Object) result);
            String output = (String) context.getValue("choices[1]/message/content");
            if (parameters.getTools() != null) {
                output = this.runPendingToolCall(sessionId, context, messages, parameters, output);
            }
            response.setError(false);
            response.setErrorMessage(null);
            response.setMessage(output);
            if (callback != null) {
                callback.responseChunkReceived(response, true);
                callback.finalVerificationResult(((PromptParameters) parameters).getResponseVerifier().verify(response));
            }
            return (T) response;
        } catch (Exception e) {
            // NOTE(review): a supplied callback is not notified of this failure; the callback
            // variant of _bark ignores the returned response, so the error only reaches the log.
            this.classLogger().error("{}\r\n{}", e.getMessage(), ret, e);
            response.setError(true);
            response.setErrorMessage(e.getMessage());
            return (T) response;
        }
    }

    /**
     * Executes the first tool call in a non-streamed reply that has no recorded response yet,
     * appends the assistant tool-call message and the tool result to {@code messages}, and asks the
     * model again.
     *
     * @param output reply text to return unchanged when there is no pending tool call
     * @return the follow-up reply text, or {@code output} when nothing was executed
     * @throws Exception if a tool call is pending but no tool call processor is set, or the
     *     follow-up request fails
     */
    private String runPendingToolCall(String sessionId, JXPathContext context, List<Conversation> messages, S parameters, String output) throws Exception {
        HashMap toolCallRequestMap = (HashMap) this.toolCallRequests.get(sessionId);
        if (toolCallRequestMap == null) {
            toolCallRequestMap = new HashMap();
            this.toolCallRequests.put(sessionId, toolCallRequestMap);
        }
        // NOTE(review): nothing in this class adds entries to toolCallRequestMap, so the
        // containsKey lookup below only hits if the superclass populates it — confirm.
        List toolCalls = (List) context.getValue("choices[1]/message/tool_calls");
        if (toolCalls == null || toolCalls.isEmpty()) {
            return output;
        }
        ToolCallRequestConversation toolCallConversation = new ToolCallRequestConversation();
        toolCallConversation.setRole("assistant");
        toolCallConversation.setContent("");
        ArrayList<ToolCallRequest> requestedCalls = new ArrayList<ToolCallRequest>();
        toolCallConversation.setTool_calls(requestedCalls);
        ToolCallRequest pending = null;
        for (int i = 0; i < toolCalls.size(); i++) {
            Map toolCall = (Map) toolCalls.get(i);
            Object id = toolCall.get("id");
            ToolCallRequest toolCallObj = toolCallRequestMap.containsKey(id)
                    ? (ToolCallRequest) toolCallRequestMap.get(id)
                    : newToolCallRequest(i, toolCall);
            requestedCalls.add(toolCallObj);
            // Only the first unanswered call is executed; the list sent back ends with it.
            if (toolCallObj.getResponse() == null) {
                pending = toolCallObj;
                break;
            }
        }
        if (pending == null) {
            return output;
        }
        if (this.toolCallProcessor == null) {
            this.classLogger().warn("tool call processor is not set");
            throw new Exception("tool call exists but tool call processor is not set");
        }
        pending.setResponse(this.toolCallProcessor.processToolCallRequest(pending));
        ToolCallResponseConversation toolResult = new ToolCallResponseConversation("tool", pending.getResponse(), pending.getId(), pending.getFunction().getName());
        // Mutates the caller's list so the follow-up request carries the tool exchange.
        messages.add(toolCallConversation);
        messages.add(toolResult);
        // NOTE(review): only the follow-up message is used; an error flag on it is not propagated.
        T followUp = this._bark(sessionId, messages, parameters);
        return ((Response) followUp).getMessage();
    }

    /** Converts one {@code tool_calls} entry of the provider reply into a {@link ToolCallRequest}. */
    private static ToolCallRequest newToolCallRequest(int index, Map toolCall) {
        ToolCallRequest request = new ToolCallRequest();
        request.setIndex(index);
        request.setId((String) toolCall.get("id"));
        Map functionMap = (Map) toolCall.get("function");
        ToolCallRequest.Function function = new ToolCallRequest.Function();
        function.setName((String) functionMap.get("name"));
        function.setArguments((String) functionMap.get("arguments"));
        request.setFunction(function);
        return request;
    }

    /** Logger named after the concrete runtime class. */
    private Logger classLogger() {
        return LoggerFactory.getLogger(this.getClass());
    }

    /**
     * Mutable holder for everything needed to issue one chat request: session id, model, the
     * messages placed in the payload ({@code effectiveMessages}), the original messages, prompt
     * parameters, the stream flag and the HTTP request the body is sent on.
     */
    public static class RequestParameters {
        private String sessionId;
        private String model;
        private List<Conversation> effectiveMessages;
        private List<Conversation> originalMessages;
        private OpenAICompatiblePromptParameters parameters;
        private boolean stream;
        private HttpRequest httpRequest;

        @Override
        public int hashCode() {
            // Same 31-based combination (and 1231/1237 for the boolean) as a hand-written version.
            return Objects.hash(this.sessionId, this.model, this.effectiveMessages, this.originalMessages, this.parameters, this.stream, this.httpRequest);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            RequestParameters other = (RequestParameters) obj;
            return this.stream == other.stream
                    && Objects.equals(this.sessionId, other.sessionId)
                    && Objects.equals(this.model, other.model)
                    && Objects.equals(this.effectiveMessages, other.effectiveMessages)
                    && Objects.equals(this.originalMessages, other.originalMessages)
                    && Objects.equals(this.parameters, other.parameters)
                    && Objects.equals(this.httpRequest, other.httpRequest);
        }

        public String getModel() {
            return this.model;
        }

        public void setModel(String model) {
            this.model = model;
        }

        public List<Conversation> getEffectiveMessages() {
            return this.effectiveMessages;
        }

        public void setEffectiveMessages(List<Conversation> effectiveMessages) {
            this.effectiveMessages = effectiveMessages;
        }

        public List<Conversation> getOriginalMessages() {
            return this.originalMessages;
        }

        public void setOriginalMessages(List<Conversation> originalMessages) {
            this.originalMessages = originalMessages;
        }

        public OpenAICompatiblePromptParameters getParameters() {
            return this.parameters;
        }

        public void setParameters(OpenAICompatiblePromptParameters parameters) {
            this.parameters = parameters;
        }

        public boolean isStream() {
            return this.stream;
        }

        public void setStream(boolean stream) {
            this.stream = stream;
        }

        public HttpRequest getHttpRequest() {
            return this.httpRequest;
        }

        public void setHttpRequest(HttpRequest httpRequest) {
            this.httpRequest = httpRequest;
        }

        public String getSessionId() {
            return this.sessionId;
        }

        public void setSessionId(String sessionId) {
            this.sessionId = sessionId;
        }
    }

    /**
     * Line-oriented reader over the provider's reply body. In streamed callback mode each line
     * returned by {@link #next()} is parsed as a standalone JSON chunk; otherwise all lines are
     * joined with CRLF and parsed as a single JSON document.
     */
    public static abstract class StreamResultIterator {
        /** Reply body the subclass reads lines from. */
        protected InputStream inputStream;

        public StreamResultIterator(InputStream inputStream) {
            this.inputStream = inputStream;
        }

        /** Returns the next line of the reply. */
        public abstract String next();

        /** Returns {@code true} while more lines remain. */
        public abstract boolean hasNext();
    }
}

