/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.adaptiverag;

import dev.langchain4j.adaptiverag.AnswerGrader;
import dev.langchain4j.adaptiverag.ChromaStore;
import dev.langchain4j.adaptiverag.Generation;
import dev.langchain4j.adaptiverag.HallucinationGrader;
import dev.langchain4j.adaptiverag.QuestionRewriter;
import dev.langchain4j.adaptiverag.QuestionRouter;
import dev.langchain4j.adaptiverag.RetrievalGrader;
import dev.langchain4j.adaptiverag.WebSearchTool;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import java.io.FileInputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.LogManager;
import java.util.stream.Collectors;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.state.AgentState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adaptive RAG workflow expressed as a langgraph4j {@link StateGraph}.
 *
 * <p>The graph does the following:
 * <ul>
 *   <li>routes the incoming question either to a web search or to vector-store retrieval;</li>
 *   <li>grades retrieved documents for relevance;</li>
 *   <li>rewrites the question when no relevant document remains;</li>
 *   <li>generates an answer;</li>
 *   <li>grades that answer for grounding in the documents and for whether it addresses the question.</li>
 * </ul>
 */
public class AdaptiveRag {
    private static final Logger log = LoggerFactory.getLogger("AdaptiveRag");

    private final String openApiKey;
    private final String tavilyApiKey;

    /**
     * Holder for the lazily created {@link ChromaStore} (see {@link #getChroma()}).
     *
     * <p>It is typed as {@code Object} because the reference itself is stored as a sentinel
     * when {@link #openChroma()} returns {@code null}.
     */
    private final AtomicReference<Object> chroma = new AtomicReference<>();

    /**
     * Creates the workflow.
     *
     * @param openApiKey   OpenAI API key, passed to the LLM-backed helpers and to {@link ChromaStore}
     * @param tavilyApiKey Tavily API key, passed to {@link WebSearchTool}
     * @throws NullPointerException if either key is {@code null}
     */
    public AdaptiveRag(String openApiKey, String tavilyApiKey) {
        Objects.requireNonNull(openApiKey, "no OPENAI APIKEY provided!");
        Objects.requireNonNull(tavilyApiKey, "no TAVILY APIKEY provided!");
        this.openApiKey = openApiKey;
        this.tavilyApiKey = tavilyApiKey;
    }

    /** Creates a new {@link ChromaStore} bound to the OpenAI key. */
    private ChromaStore openChroma() {
        return ChromaStore.of(this.openApiKey);
    }

    /**
     * Graph node: searches the vector store with the current question.
     *
     * @param state current graph state; must contain {@code question}
     * @return updates for {@code documents} (the text of every match) and {@code question}
     */
    private Map<String, Object> retrieve(State state) {
        log.debug("---RETRIEVE---");
        String question = state.question();
        EmbeddingSearchResult<TextSegment> relevant = this.getChroma().search(question);
        List<String> documents = relevant.matches().stream()
                .map(match -> match.embedded().text())
                .collect(Collectors.toList());
        return Map.of("documents", documents, "question", question);
    }

    /**
     * Graph node: generates an answer from the question and the current documents.
     *
     * @param state current graph state; must contain {@code question}
     * @return an update for {@code generation}
     */
    private Map<String, Object> generate(State state) {
        log.debug("---GENERATE---");
        String question = state.question();
        List<String> documents = state.documents();
        String generation = Generation.of(this.openApiKey).apply(question, documents);
        return Map.of("generation", generation);
    }

    /**
     * Graph node: keeps only the documents that {@link RetrievalGrader} scores as {@code "yes"}.
     *
     * @param state current graph state; must contain {@code question}
     * @return an update for {@code documents} holding the filtered list (possibly empty)
     */
    private Map<String, Object> gradeDocuments(State state) {
        log.debug("---CHECK DOCUMENT RELEVANCE TO QUESTION---");
        String question = state.question();
        List<String> documents = state.documents();
        RetrievalGrader grader = RetrievalGrader.of(this.openApiKey);
        List<String> filteredDocs = documents.stream().filter(doc -> {
            RetrievalGrader.Score score = grader.apply(RetrievalGrader.Arguments.of(question, doc));
            // Null-safe comparison, consistent with the hallucination/answer graders below.
            boolean relevant = Objects.equals(score.binaryScore, "yes");
            if (relevant) {
                log.debug("---GRADE: DOCUMENT RELEVANT---");
            } else {
                log.debug("---GRADE: DOCUMENT NOT RELEVANT---");
            }
            return relevant;
        }).collect(Collectors.toList());
        return Map.of("documents", filteredDocs);
    }

    /**
     * Graph node: rewrites the question with {@link QuestionRewriter}.
     *
     * @param state current graph state; must contain {@code question}
     * @return an update replacing {@code question} with the rewritten one
     */
    private Map<String, Object> transformQuery(State state) {
        log.debug("---TRANSFORM QUERY---");
        String question = state.question();
        String betterQuestion = QuestionRewriter.of(this.openApiKey).apply(question);
        return Map.of("question", betterQuestion);
    }

    /**
     * Graph node: runs a web search and merges all result texts into one document.
     *
     * @param state current graph state; must contain {@code question}
     * @return an update for {@code documents} holding a single newline-joined entry
     */
    private Map<String, Object> webSearch(State state) {
        log.debug("---WEB SEARCH---");
        String question = state.question();
        List<Content> result = WebSearchTool.of(this.tavilyApiKey).apply(question);
        String webResult = result.stream()
                .map(content -> content.textSegment().text())
                .collect(Collectors.joining("\n"));
        return Map.of("documents", List.of(webResult));
    }

    /**
     * Conditional edge from START: asks {@link QuestionRouter} where to send the question.
     *
     * @param state current graph state; must contain {@code question}
     * @return the router's {@link QuestionRouter.Type} name. It is matched against the
     *     {@code "web_search"} / {@code "vectorstore"} keys in {@link #buildGraph()}.
     *     NOTE(review): assumes those are the only enum constant names; confirm.
     */
    private String routeQuestion(State state) {
        log.debug("---ROUTE QUESTION---");
        String question = state.question();
        QuestionRouter.Type source = QuestionRouter.of(this.openApiKey).apply(question);
        if (source == QuestionRouter.Type.web_search) {
            log.debug("---ROUTE QUESTION TO WEB SEARCH---");
        } else {
            log.debug("---ROUTE QUESTION TO RAG---");
        }
        return source.name();
    }

    /**
     * Conditional edge after grading: rewrites the query when nothing relevant remains.
     *
     * @param state current graph state
     * @return {@code "transform_query"} if {@code documents} is empty, otherwise {@code "generate"}
     */
    private String decideToGenerate(State state) {
        log.debug("---ASSESS GRADED DOCUMENTS---");
        List<String> documents = state.documents();
        if (documents.isEmpty()) {
            log.debug("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---");
            return "transform_query";
        }
        log.debug("---DECISION: GENERATE---");
        return "generate";
    }

    /**
     * Conditional edge after generation: grades the answer.
     *
     * <p>The answer is checked first for grounding in the documents, then for whether it
     * addresses the question.
     *
     * @param state current graph state; must contain {@code question} and {@code generation}
     * @return one of the following:
     *     <ul>
     *       <li>{@code "not supported"}: not grounded, so generation is retried;</li>
     *       <li>{@code "useful"}: grounded and addresses the question;</li>
     *       <li>{@code "not useful"}: grounded but does not address it.</li>
     *     </ul>
     * @throws IllegalStateException if {@code generation} is not set
     */
    private String gradeGeneration_v_documentsAndQuestion(State state) {
        log.debug("---CHECK HALLUCINATIONS---");
        String question = state.question();
        List<String> documents = state.documents();
        String generation = state.generation()
                .orElseThrow(() -> new IllegalStateException("generation is not set!"));
        HallucinationGrader.Score score = HallucinationGrader.of(this.openApiKey)
                .apply(HallucinationGrader.Arguments.of(documents, generation));
        if (Objects.equals(score.binaryScore, "yes")) {
            log.debug("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---");
            log.debug("---GRADE GENERATION vs QUESTION---");
            AnswerGrader.Score answerScore = AnswerGrader.of(this.openApiKey)
                    .apply(AnswerGrader.Arguments.of(question, generation));
            if (Objects.equals(answerScore.binaryScore, "yes")) {
                log.debug("---DECISION: GENERATION ADDRESSES QUESTION---");
                return "useful";
            }
            log.debug("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---");
            return "not useful";
        }
        log.debug("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---");
        return "not supported";
    }

    /**
     * Builds the (uncompiled) adaptive-RAG state graph.
     *
     * @return the graph wiring all nodes and conditional edges of this workflow
     * @throws Exception propagated from graph construction
     */
    public StateGraph<State> buildGraph() throws Exception {
        return new StateGraph<>(State::new)
                .addNode("web_search", AsyncNodeAction.node_async(this::webSearch))
                .addNode("retrieve", AsyncNodeAction.node_async(this::retrieve))
                .addNode("grade_documents", AsyncNodeAction.node_async(this::gradeDocuments))
                .addNode("generate", AsyncNodeAction.node_async(this::generate))
                .addNode("transform_query", AsyncNodeAction.node_async(this::transformQuery))
                // Entry: route to the web search or to the vector store.
                .addConditionalEdges(StateGraph.START,
                        AsyncEdgeAction.edge_async(this::routeQuestion),
                        Map.of("web_search", "web_search",
                                "vectorstore", "retrieve"))
                .addEdge("web_search", "generate")
                .addEdge("retrieve", "grade_documents")
                // No relevant documents left -> rewrite the question; otherwise answer.
                .addConditionalEdges("grade_documents",
                        AsyncEdgeAction.edge_async(this::decideToGenerate),
                        Map.of("transform_query", "transform_query",
                                "generate", "generate"))
                .addEdge("transform_query", "retrieve")
                // Ungrounded -> regenerate; off-topic -> rewrite; otherwise finish.
                .addConditionalEdges("generate",
                        AsyncEdgeAction.edge_async(this::gradeGeneration_v_documentsAndQuestion),
                        Map.of("not supported", "generate",
                                "useful", StateGraph.END,
                                "not useful", "transform_query"));
    }

    /**
     * Demo entry point.
     *
     * <p>It loads {@code logging.properties} from the working directory and reads API keys
     * from the {@code OPENAI_API_KEY} / {@code TAVILY_API_KEY} environment variables. It then
     * streams the graph for a fixed question, printing every visited node and finally the last
     * generation seen.
     *
     * @param args ignored
     * @throws Exception if the logging config is missing or the graph fails
     */
    public static void main(String[] args) throws Exception {
        try (FileInputStream configFile = new FileInputStream("logging.properties")) {
            LogManager.getLogManager().readConfiguration(configFile);
        }
        AdaptiveRag adaptiveRag = new AdaptiveRag(System.getenv("OPENAI_API_KEY"), System.getenv("TAVILY_API_KEY"));
        CompiledGraph<State> graph = adaptiveRag.buildGraph().compile();
        // Typed generator so the enhanced-for yields NodeOutput<State> rather than raw Object.
        AsyncGenerator<NodeOutput<State>> result = graph.stream(
                Map.of("question", "What player at the Bears expected to draft first in the 2024 NFL draft?"));
        String generation = "";
        for (NodeOutput<State> output : result) {
            System.out.printf("Node: '%s':\n", output.node());
            generation = output.state().generation().orElse("");
        }
        System.out.println(generation);
    }

    /**
     * Returns the {@link ChromaStore}, creating it via {@link #openChroma()} on first use.
     *
     * <p>Initialisation is double-checked under a lock on {@link #chroma}. Once a value has
     * been stored, {@code openChroma()} is not called again. A {@code null} result is also
     * remembered, by storing the reference itself as a sentinel, so later calls return
     * {@code null} without retrying.
     *
     * @return the lazily created store, or {@code null} if {@code openChroma()} returned null
     */
    public ChromaStore getChroma() {
        Object value = this.chroma.get();
        if (value == null) {
            synchronized (this.chroma) {
                value = this.chroma.get();
                if (value == null) {
                    ChromaStore created = this.openChroma();
                    value = created == null ? this.chroma : created;
                    this.chroma.set(value);
                }
            }
        }
        return value == this.chroma ? null : (ChromaStore) value;
    }

    /** Typed accessors over the graph state map used by this workflow. */
    public static class State
    extends AgentState {
        public State(Map<String, Object> initData) {
            super(initData);
        }

        /**
         * @return the current question
         * @throws IllegalStateException if {@code question} is not set
         */
        public String question() {
            return this.<String>value("question")
                    .orElseThrow(() -> new IllegalStateException("question is not set!"));
        }

        /** @return the current generation, if one has been produced */
        public Optional<String> generation() {
            return this.value("generation");
        }

        /** @return the current documents, or an empty list if none are set */
        public List<String> documents() {
            // Explicit type witness: in receiver position T would otherwise be inferred as Object.
            return this.<List<String>>value("documents").orElse(Collections.emptyList());
        }
    }
}

