/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import lombok.NonNull;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.CompileConfig;
import org.bsc.langgraph4j.EdgeValue;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.Channel;
import org.bsc.langgraph4j.state.StateSnapshot;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Executable view of a {@link StateGraph}, configured by a {@link CompileConfig}.
 *
 * <p>Node actions and edges are copied from the source graph when this object is constructed.
 * A run is exposed lazily as an {@link AsyncGenerator} of {@link NodeOutput}s ({@link #stream})
 * or eagerly as the final state ({@link #invoke}). When the compile config provides a checkpoint
 * saver, a checkpoint is stored at START and after every executed node, and the state can be
 * inspected ({@link #getState}, {@link #getStateHistory}), edited ({@link #updateState}) and
 * resumed (by streaming with {@code null} inputs).
 *
 * @param <State> the agent state type handled by the graph
 */
public class CompiledGraph<State extends AgentState> {
    private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);

    /** The graph definition this instance was built from. */
    final StateGraph<State> stateGraph;
    /** Node actions keyed by node id, copied from {@code stateGraph.nodes} in iteration order. */
    final Map<String, AsyncNodeAction<State>> nodes = new LinkedHashMap<>();
    /** Outgoing edge value keyed by source node id, copied from {@code stateGraph.edges}. */
    final Map<String, EdgeValue<State>> edges = new LinkedHashMap<>();
    /** Maximum number of generator steps per run (START and END steps included); defaults to 25. */
    private int maxIterations = 25;
    /** Checkpointing and interruption settings. */
    private final CompileConfig compileConfig;

    /**
     * Builds the executable graph by snapshotting the nodes and edges of {@code stateGraph}.
     *
     * @param stateGraph    the graph definition
     * @param compileConfig checkpoint saver and interrupt-before/after settings
     */
    protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfig) {
        this.stateGraph = stateGraph;
        this.compileConfig = compileConfig;
        stateGraph.nodes.forEach(n -> this.nodes.put(n.id(), n.action()));
        // An edge whose sourceId repeats an earlier one replaces it (Map.put semantics).
        stateGraph.edges.forEach(e -> this.edges.put(e.sourceId(), e.target()));
    }

    /**
     * Returns the configured checkpoint saver.
     *
     * @return the checkpoint saver from the compile config
     * @throws IllegalStateException if no checkpoint saver is configured
     */
    private BaseCheckpointSaver requireCheckpointSaver() {
        return this.compileConfig.checkpointSaver().orElseThrow(() -> new IllegalStateException("Missing CheckpointSaver!"));
    }

    /**
     * Returns one snapshot for every checkpoint the saver lists for {@code config}.
     *
     * @param config the run configuration used to look up checkpoints
     * @return the snapshots, in the order returned by the saver
     * @throws IllegalStateException if no checkpoint saver is configured
     */
    public Collection<StateSnapshot<State>> getStateHistory(RunnableConfig config) {
        BaseCheckpointSaver saver = this.requireCheckpointSaver();
        return saver.list(config).stream().map(checkpoint -> StateSnapshot.of(checkpoint, config, this.stateGraph.getStateFactory())).collect(Collectors.toList());
    }

    /**
     * Returns the snapshot of the checkpoint selected by {@code config}.
     *
     * @param config the run configuration used to look up the checkpoint
     * @return the snapshot
     * @throws IllegalStateException if no checkpoint saver is configured or no checkpoint exists
     */
    public StateSnapshot<State> getState(RunnableConfig config) {
        return this.stateOf(config).orElseThrow(() -> new IllegalStateException("Missing Checkpoint!"));
    }

    /**
     * Returns the snapshot of the checkpoint selected by {@code config}, if any.
     *
     * @param config the run configuration used to look up the checkpoint
     * @return the snapshot, or empty if the saver holds no checkpoint for {@code config}
     * @throws IllegalStateException if no checkpoint saver is configured
     */
    public Optional<StateSnapshot<State>> stateOf(RunnableConfig config) {
        BaseCheckpointSaver saver = this.requireCheckpointSaver();
        return saver.get(config).map(checkpoint -> StateSnapshot.of(checkpoint, config, this.stateGraph.getStateFactory()));
    }

    /**
     * Stores a new checkpoint built from the one selected by {@code config}, with {@code values}
     * merged into its state through the graph's channels.
     *
     * @param config identifies the checkpoint to branch from
     * @param values partial state to merge
     * @param asNode when non-null, the node whose outgoing edge (evaluated against the updated
     *               state) decides the {@code nextNode} of the returned config; may be {@code null}
     * @return a config pointing at the new checkpoint
     * @throws IllegalStateException if no checkpoint saver is configured or no checkpoint exists
     * @throws Exception if evaluating the edge of {@code asNode} or storing the checkpoint fails
     */
    public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values, String asNode) throws Exception {
        BaseCheckpointSaver saver = this.requireCheckpointSaver();
        Checkpoint branchCheckpoint = saver.get(config)
                .map(Checkpoint::new)
                .map(cp -> cp.updateState(values, this.stateGraph.getChannels()))
                .orElseThrow(() -> new IllegalStateException("Missing Checkpoint!"));
        // Resolve the next node before saving, so a failing edge evaluation stores nothing.
        String nextNodeId = null;
        if (asNode != null) {
            nextNodeId = this.nextNodeId(asNode, branchCheckpoint.getState());
        }
        RunnableConfig newConfig = saver.put(config, branchCheckpoint);
        return RunnableConfig.builder(newConfig).checkPointId(branchCheckpoint.getId()).nextNode(nextNodeId).build();
    }

    /**
     * Same as {@link #updateState(RunnableConfig, Map, String)} with no {@code asNode}, so the
     * returned config carries a {@code null} next node.
     *
     * @param config identifies the checkpoint to branch from
     * @param values partial state to merge
     * @return a config pointing at the new checkpoint
     * @throws Exception if storing the checkpoint fails
     */
    public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values) throws Exception {
        return this.updateState(config, values, null);
    }

    /** @return the entry-point edge of the underlying graph */
    public EdgeValue<State> getEntryPoint() {
        return this.stateGraph.getEntryPoint();
    }

    /** @return the finish point of the underlying graph */
    public String getFinishPoint() {
        return this.stateGraph.getFinishPoint();
    }

    /**
     * Sets the maximum number of generator steps per run.
     *
     * @param maxIterations strictly positive step limit
     * @throws IllegalArgumentException if {@code maxIterations <= 0}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be > 0!");
        }
        this.maxIterations = maxIterations;
    }

    /**
     * Resolves the node that follows {@code nodeId} along {@code route}.
     *
     * @param route  the outgoing edge; fixed ({@code id()}) or conditional ({@code value()})
     * @param state  current state data, used to evaluate a conditional edge
     * @param nodeId source node id, used in error messages
     * @return the id of the next node
     * @throws Exception if the edge is missing or invalid, the condition result has no mapping,
     *                   or the condition future fails
     */
    private String nextNodeId(EdgeValue<State> route, Map<String, Object> state, String nodeId) throws Exception {
        if (route == null) {
            throw StateGraph.RunnableErrors.missingEdge.exception(nodeId);
        }
        // Fixed edge: the target id is stored directly.
        if (route.id() != null) {
            return route.id();
        }
        // Conditional edge: evaluate the condition (blocking) and map its result to a node id.
        if (route.value() != null) {
            AgentState derefState = (AgentState) this.stateGraph.getStateFactory().apply(state);
            AsyncEdgeAction<State> condition = route.value().action();
            String newRoute = condition.apply((State) derefState).get();
            String result = route.value().mappings().get(newRoute);
            if (result == null) {
                throw StateGraph.RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
            }
            return result;
        }
        throw StateGraph.RunnableErrors.executionError.exception(String.format("invalid edge value for nodeId: [%s] !", nodeId));
    }

    /** Resolves the node following {@code nodeId} using the edge registered for it. */
    private String nextNodeId(String nodeId, Map<String, Object> state) throws Exception {
        return this.nextNodeId(this.edges.get(nodeId), state, nodeId);
    }

    /** Resolves the first node to execute from the graph's entry-point edge. */
    private String getEntryPoint(Map<String, Object> state) throws Exception {
        return this.nextNodeId(this.stateGraph.getEntryPoint(), state, "entryPoint");
    }

    /**
     * Tells whether the run must stop before executing {@code nodeId}.
     *
     * <p>{@code previousNodeId} is {@code null} right after a resume from a checkpoint, so a
     * resumed run is never interrupted again before the node it stopped at.
     *
     * @param nodeId         the node about to be executed
     * @param previousNodeId the node executed last, or {@code null}
     * @return {@code true} if {@code nodeId} is listed in interrupt-before and this is not a resume
     * @throws NullPointerException if {@code nodeId} is null
     */
    private boolean shouldInterruptBefore(@NonNull String nodeId, String previousNodeId) {
        Objects.requireNonNull(nodeId, "nodeId is marked non-null but is null");
        return previousNodeId != null && Arrays.asList(this.compileConfig.getInterruptBefore()).contains(nodeId);
    }

    /**
     * Tells whether the run must stop after {@code nodeId} has been executed.
     *
     * @param nodeId         the node executed last, or {@code null} right after a resume
     * @param previousNodeId unused; kept for symmetry with {@link #shouldInterruptBefore}
     * @return {@code true} if {@code nodeId} is non-null and listed in interrupt-after
     */
    private boolean shouldInterruptAfter(String nodeId, String previousNodeId) {
        if (nodeId == null) {
            return false;
        }
        return Arrays.asList(this.compileConfig.getInterruptAfter()).contains(nodeId);
    }

    /**
     * Stores a checkpoint with a clone of {@code state}, if a checkpoint saver is configured.
     *
     * @return the stored checkpoint, or empty when no saver is configured
     * @throws Exception if cloning the state or storing the checkpoint fails
     */
    private Optional<Checkpoint> addCheckpoint(RunnableConfig config, String nodeId, Map<String, Object> state, String nextNodeId) throws Exception {
        if (this.compileConfig.checkpointSaver().isPresent()) {
            Checkpoint cp = Checkpoint.builder().nodeId(nodeId).state((AgentState) this.cloneState(state)).nextNodeId(nextNodeId).build();
            this.compileConfig.checkpointSaver().get().put(config, cp);
            return Optional.of(cp);
        }
        return Optional.empty();
    }

    /**
     * Builds the initial state from the default values declared by the graph's channels.
     *
     * <p>NOTE(review): {@link Collectors#toMap} throws a {@link NullPointerException} if a default
     * supplier returns {@code null}; confirm channel defaults are always non-null.
     *
     * @return channel name to default value, for channels that declare a default
     */
    Map<String, Object> getInitialStateFromSchema() {
        return this.stateGraph.getChannels().entrySet().stream().filter(c -> ((Channel) c.getValue()).getDefault().isPresent()).collect(Collectors.toMap(Map.Entry::getKey, e -> ((Channel) e.getValue()).getDefault().get().get()));
    }

    /**
     * Merges {@code inputs} into the state of the checkpoint selected by {@code config} when a
     * checkpoint saver holds one; otherwise merges them into the schema defaults.
     */
    Map<String, Object> getInitialState(Map<String, Object> inputs, RunnableConfig config) {
        return this.compileConfig.checkpointSaver().flatMap(saver -> saver.get(config)).map(cp -> AgentState.updateState(cp.getState(), inputs, this.stateGraph.getChannels())).orElseGet(() -> AgentState.updateState(this.getInitialStateFromSchema(), inputs, this.stateGraph.getChannels()));
    }

    /** Clones {@code data} into a new state object using the graph's state serializer. */
    State cloneState(Map<String, Object> data) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {
        return this.stateGraph.getStateSerializer().cloneObject(data);
    }

    /**
     * Starts a lazy run of the graph.
     *
     * @param inputs initial inputs, or {@code null} to resume from the checkpoint selected by
     *               {@code config} (requires a checkpoint saver)
     * @param config run configuration; must not be null
     * @return a generator yielding one output per step
     * @throws Exception if the run cannot be initialized
     */
    public AsyncGenerator<NodeOutput<State>> stream(Map<String, Object> inputs, RunnableConfig config) throws Exception {
        Objects.requireNonNull(config, "config cannot be null");
        AsyncNodeGenerator generator = new AsyncNodeGenerator(inputs, config);
        return new AsyncGenerator.WithEmbed(generator);
    }

    /** Same as {@link #stream(Map, RunnableConfig)} with a default {@link RunnableConfig}. */
    public AsyncGenerator<NodeOutput<State>> stream(Map<String, Object> inputs) throws Exception {
        return this.stream(inputs, RunnableConfig.builder().build());
    }

    /**
     * Runs the graph until its generator finishes and returns the state of the last output.
     *
     * @param inputs initial inputs, or {@code null} to resume (see {@link #stream})
     * @param config run configuration; must not be null
     * @return the last emitted state, or empty if no output was produced
     * @throws Exception if the run cannot be initialized
     */
    public Optional<State> invoke(Map<String, Object> inputs, RunnableConfig config) throws Exception {
        Iterator<NodeOutput<State>> outputs = this.stream(inputs, config).iterator();
        Stream<NodeOutput<State>> results = StreamSupport.stream(Spliterators.spliteratorUnknownSize(outputs, Spliterator.ORDERED), false);
        return results.reduce((previous, latest) -> latest).map(NodeOutput::state);
    }

    /** Same as {@link #invoke(Map, RunnableConfig)} with a default {@link RunnableConfig}. */
    public Optional<State> invoke(Map<String, Object> inputs) throws Exception {
        return this.invoke(inputs, RunnableConfig.builder().build());
    }

    /**
     * Like {@link #stream(Map, RunnableConfig)}, but forces {@link StreamMode#SNAPSHOTS}: steps for
     * which a checkpoint was stored are emitted as {@link StateSnapshot}s.
     */
    public AsyncGenerator<NodeOutput<State>> streamSnapshots(Map<String, Object> inputs, RunnableConfig config) throws Exception {
        Objects.requireNonNull(config, "config cannot be null");
        AsyncNodeGenerator generator = new AsyncNodeGenerator(inputs, config.withStreamMode(StreamMode.SNAPSHOTS));
        return new AsyncGenerator.WithEmbed(generator);
    }

    /**
     * Renders the graph.
     *
     * @param type                  output format
     * @param title                 diagram title
     * @param printConditionalEdges whether conditional edges are rendered
     * @return the rendered representation
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
        String content = type.generator.generate(this.stateGraph, title, printConditionalEdges);
        return new GraphRepresentation(type, content);
    }

    /** Renders the graph including conditional edges. */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        return this.getGraph(type, title, true);
    }

    /** Renders the graph with the title "Graph Diagram", including conditional edges. */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
        return this.getGraph(type, "Graph Diagram", true);
    }

    /** @return the internal, mutable map of node actions keyed by node id */
    public Map<String, AsyncNodeAction<State>> getNodes() {
        return this.nodes;
    }

    /** @return the internal, mutable map of edges keyed by source node id */
    public Map<String, EdgeValue<State>> getEdges() {
        return this.edges;
    }

    /**
     * Drives a single run of the graph. Each call to {@link #next()} performs one step: emitting
     * START, executing one node, or emitting END.
     *
     * @param <Output> the output type emitted per step
     */
    public class AsyncNodeGenerator<Output extends NodeOutput<State>>
    implements AsyncGenerator<Output> {
        /** State data accumulated so far. */
        Map<String, Object> currentState;
        /** Node executed last; START before the first step, null after END or right after a resume. */
        String currentNodeId;
        /** Node to execute at the next step. */
        String nextNodeId;
        /** Number of {@link #next()} calls so far. */
        int iteration = 0;
        RunnableConfig config;
        /** Set when an embedded generator completed and its node's output is still to be emitted. */
        boolean resumedFromEmbed = false;

        /**
         * Prepares a fresh run, or a resume when {@code inputs} is {@code null}.
         *
         * @throws IllegalStateException on a resume without a checkpoint saver or saved checkpoint
         */
        protected AsyncNodeGenerator(Map<String, Object> inputs, RunnableConfig config) throws Exception {
            boolean isResumeRequest = inputs == null;
            if (isResumeRequest) {
                log.trace("RESUME REQUEST");
                BaseCheckpointSaver saver = CompiledGraph.this.compileConfig.checkpointSaver().orElseThrow(() -> new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured"));
                Checkpoint startCheckpoint = saver.get(config).orElseThrow(() -> new IllegalStateException("Resume request without a saved checkpoint!"));
                this.currentState = startCheckpoint.getState();
                this.config = config.withCheckPointId(null);
                this.nextNodeId = startCheckpoint.getNextNodeId();
                // null marks "resumed": suppresses interrupt checks on the first step.
                this.currentNodeId = null;
                log.trace("RESUME FROM {}", startCheckpoint.getNodeId());
            } else {
                log.trace("START");
                Map<String, Object> initState = CompiledGraph.this.getInitialState(inputs, config);
                AgentState initializedState = (AgentState) CompiledGraph.this.stateGraph.getStateFactory().apply(initState);
                this.currentState = initializedState.data();
                this.nextNodeId = null;
                this.currentNodeId = StateGraph.START;
                this.config = config;
            }
        }

        /** Emits {@code nodeId} with a clone of the current state. */
        protected Output buildNodeOutput(String nodeId) throws Exception {
            return (Output) NodeOutput.of(nodeId, CompiledGraph.this.cloneState(this.currentState));
        }

        /** Emits a snapshot of {@code checkpoint}. */
        protected Output buildStateSnapshot(Checkpoint checkpoint) throws Exception {
            return (Output) StateSnapshot.of(checkpoint, this.config, CompiledGraph.this.stateGraph.getStateFactory());
        }

        /**
         * If {@code partialState} contains an {@link AsyncGenerator}, wraps it so that, on completion,
         * the other keys of {@code partialState} and then the generator's result are merged into
         * the current state. The next node is then resolved.
         *
         * @return the composed generator data, or empty if {@code partialState} holds no generator
         */
        private Optional<AsyncGenerator.Data<Output>> getEmbedGenerator(Map<String, Object> partialState) {
            return partialState.entrySet().stream()
                    .filter(e -> e.getValue() instanceof AsyncGenerator)
                    .findFirst()
                    .map(generatorEntry -> {
                        // Keep whatever the node returned besides the generator; the generator itself
                        // must not be merged into the (cloneable) state.
                        Map<String, Object> partialStateWithoutGenerator = new LinkedHashMap<>(partialState);
                        partialStateWithoutGenerator.remove(generatorEntry.getKey());
                        return AsyncGenerator.Data.composeWith((AsyncGenerator) generatorEntry.getValue(), data -> {
                            if (!(data instanceof Map)) {
                                throw new IllegalArgumentException("Embedded generator must return a Map");
                            }
                            Map<String, Object> intermediateState = AgentState.updateState(this.currentState, partialStateWithoutGenerator, CompiledGraph.this.stateGraph.getChannels());
                            this.currentState = AgentState.updateState(intermediateState, (Map<String, Object>) data, CompiledGraph.this.stateGraph.getChannels());
                            this.nextNodeId = CompiledGraph.this.nextNodeId(this.currentNodeId, this.currentState);
                            this.resumedFromEmbed = true;
                        });
                    });
        }

        /**
         * Applies {@code action} to {@code withState}. An embedded generator in the result is
         * delegated to; otherwise the partial state is merged, the next node resolved and the
         * step's output produced. Checked failures surface as {@link CompletionException}.
         */
        private CompletableFuture<AsyncGenerator.Data<Output>> evaluateAction(AsyncNodeAction<State> action, State withState) {
            return action.apply(withState).thenApply(partialState -> {
                try {
                    Optional<AsyncGenerator.Data<Output>> embed = this.getEmbedGenerator((Map<String, Object>) partialState);
                    if (embed.isPresent()) {
                        return embed.get();
                    }
                    this.currentState = AgentState.updateState(this.currentState, (Map<String, Object>) partialState, CompiledGraph.this.stateGraph.getChannels());
                    this.nextNodeId = CompiledGraph.this.nextNodeId(this.currentNodeId, this.currentState);
                    return AsyncGenerator.Data.of(this.getNodeOutput());
                }
                catch (Exception e) {
                    throw new CompletionException(e);
                }
            });
        }

        /**
         * Stores a checkpoint for the current step (if a saver is configured) and builds the output:
         * a snapshot in {@link StreamMode#SNAPSHOTS} mode when a checkpoint was stored, otherwise a
         * plain node output.
         */
        private CompletableFuture<Output> getNodeOutput() throws Exception {
            Optional<Checkpoint> cp = CompiledGraph.this.addCheckpoint(this.config, this.currentNodeId, this.currentState, this.nextNodeId);
            Output output = cp.isPresent() && this.config.streamMode() == StreamMode.SNAPSHOTS
                    ? this.buildStateSnapshot(cp.get())
                    : this.buildNodeOutput(this.currentNodeId);
            return CompletableFuture.completedFuture(output);
        }

        /**
         * Performs one step of the run.
         *
         * @return the step's output, {@code done} when the run ends, is interrupted or exceeds the
         *         iteration limit, or {@code error} if the step fails
         */
        @Override
        public AsyncGenerator.Data<Output> next() {
            // Every call, START and END included, counts toward the iteration budget.
            if (++this.iteration > CompiledGraph.this.maxIterations) {
                log.warn("Maximum number of iterations ({}) reached!", CompiledGraph.this.maxIterations);
                return AsyncGenerator.Data.done();
            }
            // Both null: END was already emitted, or a resume found no next node.
            if (this.nextNodeId == null && this.currentNodeId == null) {
                return AsyncGenerator.Data.done();
            }
            try {
                // An embedded generator completed: emit the output of the node that returned it.
                if (this.resumedFromEmbed) {
                    CompletableFuture<Output> future = this.getNodeOutput();
                    this.resumedFromEmbed = false;
                    return AsyncGenerator.Data.of(future);
                }
                if (StateGraph.START.equals(this.currentNodeId)) {
                    this.currentNodeId = this.nextNodeId = CompiledGraph.this.getEntryPoint(this.currentState);
                    CompiledGraph.this.addCheckpoint(this.config, StateGraph.START, this.currentState, this.nextNodeId);
                    return AsyncGenerator.Data.of(this.buildNodeOutput(StateGraph.START));
                }
                if (StateGraph.END.equals(this.nextNodeId)) {
                    this.nextNodeId = null;
                    this.currentNodeId = null;
                    return AsyncGenerator.Data.of(this.buildNodeOutput(StateGraph.END));
                }
                if (CompiledGraph.this.shouldInterruptAfter(this.currentNodeId, this.nextNodeId)) {
                    return AsyncGenerator.Data.done();
                }
                if (CompiledGraph.this.shouldInterruptBefore(this.nextNodeId, this.currentNodeId)) {
                    return AsyncGenerator.Data.done();
                }
                this.currentNodeId = this.nextNodeId;
                AsyncNodeAction action = CompiledGraph.this.nodes.get(this.currentNodeId);
                if (action == null) {
                    throw StateGraph.RunnableErrors.missingNode.exception(this.currentNodeId);
                }
                // The node receives a clone, so it cannot mutate currentState directly.
                return this.evaluateAction(action, CompiledGraph.this.cloneState(this.currentState)).get();
            }
            catch (Exception e) {
                log.error(e.getMessage(), e);
                return AsyncGenerator.Data.error(e);
            }
        }
    }

    /** Shape of the outputs emitted by a run. */
    public enum StreamMode {
        /** Plain node outputs carrying a clone of the state. */
        VALUES,
        /** {@link StateSnapshot}s for steps where a checkpoint was stored. */
        SNAPSHOTS
    }
}

