/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j;

import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.NonNull;
import org.bsc.langgraph4j.CompileConfig;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.Edge;
import org.bsc.langgraph4j.EdgeCondition;
import org.bsc.langgraph4j.EdgeValue;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.GraphRunnerException;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.Node;
import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;
import org.bsc.langgraph4j.state.Channel;
import org.bsc.langgraph4j.utils.CollectionsUtils;

/**
 * Mutable builder for a graph of nodes and edges operating on a shared {@code State}.
 *
 * <p>Nodes are registered with {@link #addNode}, connected with {@link #addEdge} or
 * {@link #addConditionalEdges}, and the whole graph is validated and turned into a
 * {@link CompiledGraph} by {@link #compile(CompileConfig)}. Edges whose source is
 * {@link #START} define the graph entry point; {@link #END} may be used as an edge target.
 *
 * <p>NOTE(review): no synchronization is performed here; presumably instances are built from a
 * single thread before compilation — confirm.
 *
 * @param <State> the agent state type flowing through the graph
 */
public class StateGraph<State extends AgentState> {
    /** Reserved identifier of the virtual termination node; valid only as an edge target. */
    public static final String END = "__END__";
    /** Reserved identifier of the virtual start node; edges from it set the entry point. */
    public static final String START = "__START__";

    // Insertion-ordered sets; duplicates are detected through Node/Edge equality.
    // Package-private because other classes in this package may read them directly.
    Set<Node<State>> nodes = new LinkedHashSet<>();
    Set<Edge<State>> edges = new LinkedHashSet<>();
    private EdgeValue<State> entryPoint;
    private String finishPoint;
    private final Map<String, Channel<?>> channels;
    private final StateSerializer<State> stateSerializer;

    /**
     * Creates a graph with explicit channels and state serializer.
     *
     * @param channels        channel definitions, exposed read-only via {@link #getChannels()}
     * @param stateSerializer serializer for the state; also supplies {@link #getStateFactory()}
     */
    public StateGraph(Map<String, Channel<?>> channels, StateSerializer<State> stateSerializer) {
        this.channels = channels;
        this.stateSerializer = stateSerializer;
    }

    /**
     * Creates a graph with no channels and the given state serializer.
     *
     * @param stateSerializer serializer for the state; must not be {@code null}
     * @throws NullPointerException if {@code stateSerializer} is {@code null}
     */
    public StateGraph(@NonNull StateSerializer<State> stateSerializer) {
        this(CollectionsUtils.mapOf(), stateSerializer);
        if (stateSerializer == null) {
            throw new NullPointerException("stateSerializer is marked non-null but is null");
        }
    }

    /**
     * Creates a graph with no channels, serializing state with an
     * {@link ObjectStreamStateSerializer} built from {@code stateFactory}.
     *
     * @param stateFactory factory used to create state instances
     */
    public StateGraph(AgentStateFactory<State> stateFactory) {
        this(CollectionsUtils.mapOf(), stateFactory);
    }

    /**
     * Creates a graph with the given channels, serializing state with an
     * {@link ObjectStreamStateSerializer} built from {@code stateFactory}.
     *
     * @param channels     channel definitions, exposed read-only via {@link #getChannels()}
     * @param stateFactory factory used to create state instances
     */
    public StateGraph(Map<String, Channel<?>> channels, AgentStateFactory<State> stateFactory) {
        this(channels, new ObjectStreamStateSerializer<>(stateFactory));
    }

    /** Returns the state factory exposed by this graph's state serializer. */
    public final AgentStateFactory<State> getStateFactory() {
        return stateSerializer.stateFactory();
    }

    /**
     * Returns an unmodifiable view of the channel definitions.
     *
     * @throws NullPointerException if the graph was constructed with a {@code null} channel map
     */
    public Map<String, Channel<?>> getChannels() {
        return Collections.unmodifiableMap(channels);
    }

    /**
     * Returns the current entry point, or {@code null} if none has been set.
     *
     * @deprecated NOTE(review): no replacement is visible in this class.
     */
    @Deprecated
    public EdgeValue<State> getEntryPoint() {
        return entryPoint;
    }

    /**
     * Returns the configured finish point, or {@code null} if none has been set.
     *
     * @deprecated NOTE(review): no replacement is visible in this class.
     */
    @Deprecated
    public String getFinishPoint() {
        return finishPoint;
    }

    /**
     * Sets an unconditional entry point.
     *
     * @deprecated use {@code addEdge(START, entryPoint)}, which has the same effect.
     */
    @Deprecated
    public void setEntryPoint(String entryPoint) {
        this.entryPoint = new EdgeValue<>(entryPoint, null);
    }

    /**
     * Sets a conditional entry point.
     *
     * @deprecated use {@code addConditionalEdges(START, condition, mappings)}, which this delegates to.
     */
    @Deprecated
    public void setConditionalEntryPoint(AsyncEdgeAction<State> condition, Map<String, String> mappings) throws GraphStateException {
        addConditionalEdges(START, condition, mappings);
    }

    /**
     * Sets the finish point; it is only checked for existence in {@link #compile(CompileConfig)}.
     *
     * @deprecated NOTE(review): presumably superseded by {@code addEdge(nodeId, END)} — confirm.
     */
    @Deprecated
    public void setFinishPoint(String finishPoint) {
        this.finishPoint = finishPoint;
    }

    /**
     * Registers a node.
     *
     * @param id     node identifier; must not be {@link #END}
     * @param action action executed for this node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is {@link #END} or an equal node already exists
     */
    public StateGraph<State> addNode(String id, AsyncNodeAction<State> action) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        // Set.add returns false when an equal node is already present.
        if (!nodes.add(new Node<>(id, action))) {
            throw Errors.duplicateNodeError.exception(id);
        }
        return this;
    }

    /**
     * Adds an unconditional edge. An edge from {@link #START} replaces the entry point instead of
     * being stored as an edge.
     *
     * @param sourceId source node identifier; must not be {@link #END}
     * @param targetId target node identifier (existence is checked at compile time)
     * @return this graph, for chaining
     * @throws GraphStateException if {@code sourceId} is {@link #END} or an equal edge already exists
     */
    public StateGraph<State> addEdge(String sourceId, String targetId) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        EdgeValue<State> target = new EdgeValue<>(targetId, null);
        if (Objects.equals(sourceId, START)) {
            entryPoint = target;
            return this;
        }
        return registerEdge(new Edge<>(sourceId, target));
    }

    /**
     * Adds a conditional edge whose target is chosen at run time through {@code mappings}.
     * A conditional edge from {@link #START} replaces the entry point instead of being stored.
     *
     * @param sourceId  source node identifier; must not be {@link #END}
     * @param condition action evaluated to select a mapping key
     * @param mappings  non-empty map from condition result to target node id; stored by reference
     * @return this graph, for chaining
     * @throws GraphStateException if {@code sourceId} is {@link #END}, {@code mappings} is
     *                             {@code null} or empty, or an equal edge already exists
     */
    public StateGraph<State> addConditionalEdges(String sourceId, AsyncEdgeAction<State> condition, Map<String, String> mappings) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(sourceId);
        }
        EdgeValue<State> target = new EdgeValue<>(null, new EdgeCondition<>(condition, mappings));
        if (Objects.equals(sourceId, START)) {
            entryPoint = target;
            return this;
        }
        return registerEdge(new Edge<>(sourceId, target));
    }

    /** Stores {@code edge}, rejecting it if an equal edge is already registered. */
    private StateGraph<State> registerEdge(Edge<State> edge) throws GraphStateException {
        if (!edges.add(edge)) {
            throw Errors.duplicateEdgeError.exception(edge.sourceId());
        }
        return this;
    }

    /** Builds a lookup key with the given id and no action, for membership checks against {@link #nodes}. */
    private Node<State> nodeById(String id) {
        return new Node<>(id, null);
    }

    /** Returns whether a node equal to {@code nodeById(id)} has been registered. */
    private boolean hasNode(String id) {
        return nodes.contains(nodeById(id));
    }

    /** Returns whether {@code nodeId} may be used as an edge target: {@link #END} or a registered node. */
    private boolean isValidTarget(String nodeId) {
        return Objects.equals(nodeId, END) || hasNode(nodeId);
    }

    /** Checks every conditional target id reachable from {@code sourceId}. */
    private void validateMappings(String sourceId, Iterable<String> targetIds) throws GraphStateException {
        for (String nodeId : targetIds) {
            if (!isValidTarget(nodeId)) {
                throw Errors.missingNodeInEdgeMapping.exception(sourceId, nodeId);
            }
        }
    }

    /** Checks that an entry point exists and that every node it can lead to is registered. */
    private void validateEntryPoint() throws GraphStateException {
        if (entryPoint == null) {
            throw Errors.missingEntryPoint.exception();
        }
        if (entryPoint.id() != null) {
            // The entry point must be a real node; END is not accepted here.
            if (!hasNode(entryPoint.id())) {
                throw Errors.entryPointNotExist.exception(entryPoint.id());
            }
        } else if (entryPoint.value() != null) {
            // Conditional entry points get the same mapping validation as conditional edges.
            validateMappings(START, entryPoint.value().mappings().values());
        } else {
            throw Errors.invalidEdgeTarget.exception(START);
        }
    }

    /** Checks that an edge's source exists and that its (possibly conditional) targets exist. */
    private void validateEdge(Edge<State> edge) throws GraphStateException {
        String sourceId = edge.sourceId();
        if (!hasNode(sourceId)) {
            throw Errors.missingNodeReferencedByEdge.exception(sourceId);
        }
        EdgeValue<State> target = edge.target();
        if (target.id() != null) {
            if (!isValidTarget(target.id())) {
                throw Errors.missingNodeReferencedByEdge.exception(target.id());
            }
        } else if (target.value() != null) {
            validateMappings(sourceId, target.value().mappings().values());
        } else {
            throw Errors.invalidEdgeTarget.exception(sourceId);
        }
    }

    /**
     * Validates the graph and compiles it.
     *
     * @param config compile configuration; must not be {@code null}
     * @return the compiled graph
     * @throws NullPointerException if {@code config} is {@code null}
     * @throws GraphStateException  if the entry point is missing or invalid, the finish point does
     *                              not exist, or any edge references an unknown node
     */
    public CompiledGraph<State> compile(CompileConfig config) throws GraphStateException {
        Objects.requireNonNull(config, "config cannot be null");
        validateEntryPoint();
        if (finishPoint != null && !hasNode(finishPoint)) {
            throw Errors.finishPointNotExist.exception(finishPoint);
        }
        for (Edge<State> edge : edges) {
            validateEdge(edge);
        }
        return new CompiledGraph<>(this, config);
    }

    /**
     * Compiles the graph with a default {@link CompileConfig}.
     *
     * @see #compile(CompileConfig)
     */
    public CompiledGraph<State> compile() throws GraphStateException {
        return compile(CompileConfig.builder().build());
    }

    /**
     * Renders this graph with the generator associated with {@code type}.
     *
     * @param type                   representation type whose generator is used
     * @param title                  title passed to the generator
     * @param printConditionalEdges  whether conditional edges are passed on to be printed
     * @return the generated representation
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
        String content = type.generator.generate(this, title, printConditionalEdges);
        return new GraphRepresentation(type, content);
    }

    /** Renders this graph including conditional edges; same as {@code getGraph(type, title, true)}. */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        return getGraph(type, title, true);
    }

    /** Returns the serializer used for this graph's state. */
    public StateSerializer<State> getStateSerializer() {
        return stateSerializer;
    }

    /** Build-time (graph construction / compilation) errors, formatted with {@link String#format}. */
    enum Errors {
        invalidNodeIdentifier("END is not a valid node id!"),
        invalidEdgeIdentifier("END is not a valid edge sourceId!"),
        duplicateNodeError("node with id: %s already exist!"),
        duplicateEdgeError("edge with id: %s already exist!"),
        edgeMappingIsEmpty("edge mapping is empty!"),
        missingEntryPoint("missing Entry Point"),
        entryPointNotExist("entryPoint: %s doesn't exist!"),
        finishPointNotExist("finishPoint: %s doesn't exist!"),
        missingNodeReferencedByEdge("edge sourceId: %s reference a not existent node!"),
        missingNodeInEdgeMapping("edge mapping for sourceId: %s contains a not existent nodeId %s!"),
        // Thrown when a target has neither an id nor a condition, i.e. it is uninitialized.
        invalidEdgeTarget("edge sourceId: %s has an uninitialized target value!");

        private final String errorMessage;

        Errors(String errorMessage) {
            this.errorMessage = errorMessage;
        }

        /** Creates an exception whose message is this template formatted with {@code args}. */
        GraphStateException exception(String... args) {
            return new GraphStateException(String.format(errorMessage, (Object[]) args));
        }
    }

    /** Run-time errors, formatted with {@link String#format}. */
    enum RunnableErrors {
        missingNodeInEdgeMapping("cannot find edge mapping for id: %s in conditional edge with sourceId: %s "),
        missingNode("node with id: %s doesn't exist!"),
        missingEdge("edge with sourceId: %s doesn't exist!"),
        executionError("%s");

        private final String errorMessage;

        RunnableErrors(String errorMessage) {
            this.errorMessage = errorMessage;
        }

        /** Creates an exception whose message is this template formatted with {@code args}. */
        GraphRunnerException exception(String... args) {
            return new GraphRunnerException(String.format(errorMessage, (Object[]) args));
        }
    }
}

