/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.service;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.AiServiceTokenStreamParameters;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.ChatMemoryAccess;
import dev.langchain4j.service.IllegalConfigurationException;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.TypeUtils;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.UserName;
import dev.langchain4j.service.V;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolExecutionContext;
import dev.langchain4j.service.tool.ToolExecutionResult;
import dev.langchain4j.spi.ServiceHelper;
import dev.langchain4j.spi.services.TokenStreamAdapter;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class DefaultAiServices<T>
extends AiServices<T> {
    private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();
    private final Collection<TokenStreamAdapter> tokenStreamAdapters = ServiceHelper.loadFactories(TokenStreamAdapter.class);

    DefaultAiServices(AiServiceContext context) {
        super(context);
    }

    static void validateParameters(Method method) {
        Parameter[] parameters = method.getParameters();
        if (parameters == null || parameters.length < 2) {
            return;
        }
        for (Parameter parameter : parameters) {
            V v = parameter.getAnnotation(V.class);
            UserMessage userMessage = parameter.getAnnotation(UserMessage.class);
            MemoryId memoryId = parameter.getAnnotation(MemoryId.class);
            UserName userName = parameter.getAnnotation(UserName.class);
            if (v != null || userMessage != null || memoryId != null || userName != null) continue;
            throw IllegalConfigurationException.illegalConfiguration("Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId", parameter.getName(), method.getName());
        }
    }

    @Override
    public T build() {
        this.performBasicValidation();
        if (!this.context.hasChatMemory() && ChatMemoryAccess.class.isAssignableFrom(this.context.aiServiceClass)) {
            throw IllegalConfigurationException.illegalConfiguration("In order to have a service implementing ChatMemoryAccess, please configure the ChatMemoryProvider on the '%s'.", this.context.aiServiceClass.getName());
        }
        for (Method method : this.context.aiServiceClass.getMethods()) {
            if (method.isAnnotationPresent(Moderate.class) && this.context.moderationModel == null) {
                throw IllegalConfigurationException.illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
            }
            if (method.getReturnType() == Result.class || method.getReturnType() == List.class || method.getReturnType() == Set.class) {
                TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
            }
            if (this.context.hasChatMemory()) continue;
            for (Parameter parameter : method.getParameters()) {
                if (!parameter.isAnnotationPresent(MemoryId.class)) continue;
                throw IllegalConfigurationException.illegalConfiguration("In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.", this.context.aiServiceClass.getName());
            }
        }
        Object proxyInstance = Proxy.newProxyInstance(this.context.aiServiceClass.getClassLoader(), new Class[]{this.context.aiServiceClass}, new InvocationHandler(){
            private final ExecutorService executor = Executors.newCachedThreadPool();

            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
                ArrayList<ChatMessage> messages;
                Type returnType;
                if (method.getDeclaringClass() == Object.class) {
                    return method.invoke((Object)this, args);
                }
                if (method.getDeclaringClass() == ChatMemoryAccess.class) {
                    return switch (method.getName()) {
                        case "getChatMemory" -> DefaultAiServices.this.context.chatMemoryService.getChatMemory(args[0]);
                        case "evictChatMemory" -> Boolean.valueOf(DefaultAiServices.this.context.chatMemoryService.evictChatMemory(args[0]) != null);
                        default -> throw new UnsupportedOperationException("Unknown method on ChatMemoryAccess class : " + method.getName());
                    };
                }
                DefaultAiServices.validateParameters(method);
                Object memoryId = DefaultAiServices.findMemoryId(method, args).orElse("default");
                ChatMemory chatMemory = DefaultAiServices.this.context.hasChatMemory() ? DefaultAiServices.this.context.chatMemoryService.getOrCreateChatMemory(memoryId) : null;
                Optional<dev.langchain4j.data.message.SystemMessage> systemMessage = DefaultAiServices.this.prepareSystemMessage(memoryId, method, args);
                dev.langchain4j.data.message.UserMessage userMessage = DefaultAiServices.prepareUserMessage(method, args);
                AugmentationResult augmentationResult = null;
                if (DefaultAiServices.this.context.retrievalAugmentor != null) {
                    List chatMemoryMessages = chatMemory != null ? chatMemory.messages() : null;
                    Metadata metadata = Metadata.from((ChatMessage)userMessage, (Object)memoryId, (List)chatMemoryMessages);
                    AugmentationRequest augmentationRequest = new AugmentationRequest((ChatMessage)userMessage, metadata);
                    augmentationResult = DefaultAiServices.this.context.retrievalAugmentor.augment(augmentationRequest);
                    userMessage = (dev.langchain4j.data.message.UserMessage)augmentationResult.chatMessage();
                }
                boolean streaming = (returnType = method.getGenericReturnType()) == TokenStream.class || this.canAdaptTokenStreamTo(returnType);
                boolean supportsJsonSchema = this.supportsJsonSchema();
                Optional<Object> jsonSchema = Optional.empty();
                if (supportsJsonSchema && !streaming) {
                    jsonSchema = DefaultAiServices.this.serviceOutputParser.jsonSchema(returnType);
                }
                if (!(supportsJsonSchema && !jsonSchema.isEmpty() || streaming)) {
                    userMessage = this.appendOutputFormatInstructions(returnType, userMessage);
                }
                if (chatMemory != null) {
                    systemMessage.ifPresent(arg_0 -> ((ChatMemory)chatMemory).add(arg_0));
                    chatMemory.add((ChatMessage)userMessage);
                    messages = chatMemory.messages();
                } else {
                    messages = new ArrayList<dev.langchain4j.data.message.UserMessage>();
                    systemMessage.ifPresent(messages::add);
                    messages.add((ChatMessage)userMessage);
                }
                Future<Moderation> moderationFuture = this.triggerModerationIfNeeded(method, messages);
                ToolExecutionContext toolExecutionContext = DefaultAiServices.this.context.toolService.executionContext(memoryId, userMessage);
                if (streaming) {
                    AiServiceTokenStream tokenStream = new AiServiceTokenStream(AiServiceTokenStreamParameters.builder().messages(messages).toolSpecifications(toolExecutionContext.toolSpecifications()).toolExecutors(toolExecutionContext.toolExecutors()).retrievedContents(augmentationResult != null ? augmentationResult.contents() : null).context(DefaultAiServices.this.context).memoryId(memoryId).build());
                    if (returnType == TokenStream.class) {
                        return tokenStream;
                    }
                    return this.adapt(tokenStream, returnType);
                }
                ResponseFormat responseFormat = null;
                if (supportsJsonSchema && jsonSchema.isPresent()) {
                    responseFormat = ResponseFormat.builder().type(ResponseFormatType.JSON).jsonSchema((JsonSchema)jsonSchema.get()).build();
                }
                ChatRequestParameters parameters = ChatRequestParameters.builder().toolSpecifications(toolExecutionContext.toolSpecifications()).responseFormat(responseFormat).build();
                ChatRequest chatRequest = ChatRequest.builder().messages(messages).parameters(parameters).build();
                ChatResponse chatResponse = DefaultAiServices.this.context.chatModel.chat(chatRequest);
                AiServices.verifyModerationIfNeeded(moderationFuture);
                ToolExecutionResult toolExecutionResult = DefaultAiServices.this.context.toolService.executeInferenceAndToolsLoop(chatResponse, parameters, messages, DefaultAiServices.this.context.chatModel, chatMemory, memoryId, toolExecutionContext.toolExecutors());
                chatResponse = toolExecutionResult.chatResponse();
                FinishReason finishReason = chatResponse.metadata().finishReason();
                Response response = Response.from((Object)chatResponse.aiMessage(), (TokenUsage)toolExecutionResult.tokenUsageAccumulator(), (FinishReason)finishReason);
                Object parsedResponse = DefaultAiServices.this.serviceOutputParser.parse((Response<AiMessage>)response, returnType);
                if (TypeUtils.typeHasRawClass(returnType, Result.class)) {
                    return Result.builder().content(parsedResponse).tokenUsage(toolExecutionResult.tokenUsageAccumulator()).sources(augmentationResult == null ? null : augmentationResult.contents()).finishReason(finishReason).toolExecutions(toolExecutionResult.toolExecutions()).build();
                }
                return parsedResponse;
            }

            private boolean canAdaptTokenStreamTo(Type returnType) {
                for (TokenStreamAdapter tokenStreamAdapter : DefaultAiServices.this.tokenStreamAdapters) {
                    if (!tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) continue;
                    return true;
                }
                return false;
            }

            private Object adapt(TokenStream tokenStream, Type returnType) {
                for (TokenStreamAdapter tokenStreamAdapter : DefaultAiServices.this.tokenStreamAdapters) {
                    if (!tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) continue;
                    return tokenStreamAdapter.adapt(tokenStream);
                }
                throw new IllegalStateException("Can't find suitable TokenStreamAdapter");
            }

            private boolean supportsJsonSchema() {
                return DefaultAiServices.this.context.chatModel != null && DefaultAiServices.this.context.chatModel.supportedCapabilities().contains(Capability.RESPONSE_FORMAT_JSON_SCHEMA);
            }

            private dev.langchain4j.data.message.UserMessage appendOutputFormatInstructions(Type returnType, dev.langchain4j.data.message.UserMessage userMessage) {
                String outputFormatInstructions = DefaultAiServices.this.serviceOutputParser.outputFormatInstructions(returnType);
                String text = userMessage.singleText() + outputFormatInstructions;
                userMessage = Utils.isNotNullOrBlank((String)userMessage.name()) ? dev.langchain4j.data.message.UserMessage.from((String)userMessage.name(), (String)text) : dev.langchain4j.data.message.UserMessage.from((String)text);
                return userMessage;
            }

            private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
                if (method.isAnnotationPresent(Moderate.class)) {
                    return this.executor.submit(() -> {
                        List<ChatMessage> messagesToModerate = AiServices.removeToolMessages(messages);
                        return (Moderation)DefaultAiServices.this.context.moderationModel.moderate(messagesToModerate).content();
                    });
                }
                return null;
            }
        });
        return (T)proxyInstance;
    }

    private Optional<dev.langchain4j.data.message.SystemMessage> prepareSystemMessage(Object memoryId, Method method, Object[] args) {
        return this.findSystemMessageTemplate(memoryId, method).map(systemMessageTemplate -> PromptTemplate.from((String)systemMessageTemplate).apply(DefaultAiServices.findTemplateVariables(systemMessageTemplate, method, args)).toSystemMessage());
    }

    private Optional<String> findSystemMessageTemplate(Object memoryId, Method method) {
        SystemMessage annotation = method.getAnnotation(SystemMessage.class);
        if (annotation != null) {
            return Optional.of(DefaultAiServices.getTemplate(method, "System", annotation.fromResource(), annotation.value(), annotation.delimiter()));
        }
        return this.context.systemMessageProvider.apply(memoryId);
    }

    private static Map<String, Object> findTemplateVariables(String template, Method method, Object[] args) {
        Parameter[] parameters = method.getParameters();
        HashMap<String, Object> variables = new HashMap<String, Object>();
        for (int i = 0; i < parameters.length; ++i) {
            String variableName = DefaultAiServices.getVariableName(parameters[i]);
            Object variableValue = args[i];
            variables.put(variableName, variableValue);
        }
        if (template.contains("{{it}}") && !variables.containsKey("it")) {
            String itValue = DefaultAiServices.getValueOfVariableIt(parameters, args);
            variables.put("it", itValue);
        }
        return variables;
    }

    private static String getVariableName(Parameter parameter) {
        V annotation = parameter.getAnnotation(V.class);
        if (annotation != null) {
            return annotation.value();
        }
        return parameter.getName();
    }

    private static String getValueOfVariableIt(Parameter[] parameters, Object[] args) {
        Parameter parameter;
        if (!(parameters.length != 1 || (parameter = parameters[0]).isAnnotationPresent(MemoryId.class) || parameter.isAnnotationPresent(UserMessage.class) || parameter.isAnnotationPresent(UserName.class) || parameter.isAnnotationPresent(V.class) && !DefaultAiServices.isAnnotatedWithIt(parameter))) {
            return DefaultAiServices.toString(args[0]);
        }
        for (int i = 0; i < parameters.length; ++i) {
            if (!DefaultAiServices.isAnnotatedWithIt(parameters[i])) continue;
            return DefaultAiServices.toString(args[i]);
        }
        throw IllegalConfigurationException.illegalConfiguration("Error: cannot find the value of the prompt template variable \"{{it}}\".");
    }

    private static boolean isAnnotatedWithIt(Parameter parameter) {
        V annotation = parameter.getAnnotation(V.class);
        return annotation != null && "it".equals(annotation.value());
    }

    private static dev.langchain4j.data.message.UserMessage prepareUserMessage(Method method, Object[] args) {
        String template = DefaultAiServices.getUserMessageTemplate(method, args);
        Map<String, Object> variables = DefaultAiServices.findTemplateVariables(template, method, args);
        Prompt prompt = PromptTemplate.from((String)template).apply(variables);
        Optional<String> maybeUserName = DefaultAiServices.findUserName(method.getParameters(), args);
        return maybeUserName.map(userName -> dev.langchain4j.data.message.UserMessage.from((String)userName, (String)prompt.text())).orElseGet(() -> ((Prompt)prompt).toUserMessage());
    }

    private static String getUserMessageTemplate(Method method, Object[] args) {
        Optional<String> templateFromMethodAnnotation = DefaultAiServices.findUserMessageTemplateFromMethodAnnotation(method);
        Optional<String> templateFromParameterAnnotation = DefaultAiServices.findUserMessageTemplateFromAnnotatedParameter(method.getParameters(), args);
        if (templateFromMethodAnnotation.isPresent() && templateFromParameterAnnotation.isPresent()) {
            throw IllegalConfigurationException.illegalConfiguration("Error: The method '%s' has multiple @UserMessage annotations. Please use only one.", method.getName());
        }
        if (templateFromMethodAnnotation.isPresent()) {
            return templateFromMethodAnnotation.get();
        }
        if (templateFromParameterAnnotation.isPresent()) {
            return templateFromParameterAnnotation.get();
        }
        Optional<String> templateFromTheOnlyArgument = DefaultAiServices.findUserMessageTemplateFromTheOnlyArgument(method.getParameters(), args);
        if (templateFromTheOnlyArgument.isPresent()) {
            return templateFromTheOnlyArgument.get();
        }
        throw IllegalConfigurationException.illegalConfiguration("Error: The method '%s' does not have a user message defined.", method.getName());
    }

    private static Optional<String> findUserMessageTemplateFromMethodAnnotation(Method method) {
        return Optional.ofNullable(method.getAnnotation(UserMessage.class)).map(a -> DefaultAiServices.getTemplate(method, "User", a.fromResource(), a.value(), a.delimiter()));
    }

    private static Optional<String> findUserMessageTemplateFromAnnotatedParameter(Parameter[] parameters, Object[] args) {
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(UserMessage.class)) continue;
            return Optional.of(DefaultAiServices.toString(args[i]));
        }
        return Optional.empty();
    }

    private static Optional<String> findUserMessageTemplateFromTheOnlyArgument(Parameter[] parameters, Object[] args) {
        if (parameters != null && parameters.length == 1 && parameters[0].getAnnotations().length == 0) {
            return Optional.of(DefaultAiServices.toString(args[0]));
        }
        return Optional.empty();
    }

    private static Optional<String> findUserName(Parameter[] parameters, Object[] args) {
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(UserName.class)) continue;
            return Optional.of(args[i].toString());
        }
        return Optional.empty();
    }

    private static String getTemplate(Method method, String type, String resource, String[] value, String delimiter) {
        String messageTemplate;
        if (!resource.trim().isEmpty()) {
            messageTemplate = DefaultAiServices.getResourceText(method.getDeclaringClass(), resource);
            if (messageTemplate == null) {
                throw IllegalConfigurationException.illegalConfiguration("@%sMessage's resource '%s' not found", type, resource);
            }
        } else {
            messageTemplate = String.join((CharSequence)delimiter, value);
        }
        if (messageTemplate.trim().isEmpty()) {
            throw IllegalConfigurationException.illegalConfiguration("@%sMessage's template cannot be empty", type);
        }
        return messageTemplate;
    }

    private static String getResourceText(Class<?> clazz, String resource) {
        InputStream inputStream = clazz.getResourceAsStream(resource);
        if (inputStream == null) {
            inputStream = clazz.getResourceAsStream("/" + resource);
        }
        return DefaultAiServices.getText(inputStream);
    }

    private static String getText(InputStream inputStream) {
        if (inputStream == null) {
            return null;
        }
        try (Scanner scanner = new Scanner(inputStream);){
            Scanner s = scanner.useDelimiter("\\A");
            try {
                String string;
                String string2 = string = s.hasNext() ? s.next() : "";
                if (s != null) {
                    s.close();
                }
                return string;
            }
            catch (Throwable throwable) {
                if (s != null) {
                    try {
                        s.close();
                    }
                    catch (Throwable throwable2) {
                        throwable.addSuppressed(throwable2);
                    }
                }
                throw throwable;
            }
        }
    }

    private static Optional<Object> findMemoryId(Method method, Object[] args) {
        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(MemoryId.class)) continue;
            Object memoryId = args[i];
            if (memoryId == null) {
                throw Exceptions.illegalArgument((String)"The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null", (Object[])new Object[]{parameters[i].getName(), method.getName()});
            }
            return Optional.of(memoryId);
        }
        return Optional.empty();
    }

    private static String toString(Object arg) {
        if (arg.getClass().isArray()) {
            return DefaultAiServices.arrayToString(arg);
        }
        if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) {
            return StructuredPromptProcessor.toPrompt((Object)arg).text();
        }
        return arg.toString();
    }

    private static String arrayToString(Object arg) {
        StringBuilder sb = new StringBuilder("[");
        int length = Array.getLength(arg);
        for (int i = 0; i < length; ++i) {
            sb.append(DefaultAiServices.toString(Array.get(arg, i)));
            if (i >= length - 1) continue;
            sb.append(", ");
        }
        sb.append("]");
        return sb.toString();
    }
}

