/*
 * Copyright 2017-2025 noear.org and authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.noear.solon.flow;

import org.noear.solon.Utils;
import org.noear.solon.core.util.Assert;
import org.noear.solon.core.util.RankEntity;
import org.noear.solon.flow.intercept.ChainInterceptor;
import org.noear.solon.flow.intercept.ChainInvocation;
import org.noear.solon.flow.stateful.FlowStatefulService;
import org.noear.solon.flow.stateful.FlowStatefulServiceDefault;
import org.noear.solon.flow.driver.SimpleFlowDriver;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Default flow engine implementation.
 *
 * <p>Keeps the loaded chains (keyed by chain id), the registered drivers (keyed by name; the
 * default driver is registered under {@code ""}) and a rank-sorted list of chain interceptors,
 * and evaluates a chain by walking its nodes according to each node's type.
 *
 * @author noear
 * @since 3.0
 */
public class FlowEngineDefault implements FlowEngine {
    protected final Map<String, Chain> chainMap = new ConcurrentHashMap<>();
    protected final Map<String, FlowDriver> driverMap = new ConcurrentHashMap<>();
    protected final List<RankEntity<ChainInterceptor>> interceptorList = new ArrayList<>();

    /**
     * Creates an engine whose default driver is a new {@link SimpleFlowDriver}.
     */
    public FlowEngineDefault() {
        this(null);
    }

    /**
     * Creates an engine with the given default driver.
     *
     * @param driver default driver, registered under the name {@code ""}; when {@code null},
     *               a new {@link SimpleFlowDriver} is used
     */
    public FlowEngineDefault(FlowDriver driver) {
        //Fall back to the built-in default driver
        if (driver == null) {
            driver = new SimpleFlowDriver();
        }

        driverMap.put("", driver);
    }

    /**
     * Looks up the driver registered under the chain's driver name.
     *
     * @param chain the chain (must not be null)
     * @return the registered driver
     * @throws IllegalArgumentException if no driver is registered under {@code chain.getDriver()}
     */
    @Override
    public FlowDriver getDriver(Chain chain) {
        Assert.notNull(chain, "chain is null");

        FlowDriver driver = driverMap.get(chain.getDriver());

        if (driver == null) {
            throw new IllegalArgumentException("No driver found for: '" + chain.getDriver() + "'");
        }

        return driver;
    }

    /**
     * Looks up the chain's driver and checks that it is an instance of {@code driverClass}.
     *
     * @param chain       the chain
     * @param driverClass expected driver type
     * @return the driver, typed as {@code T}
     * @throws IllegalArgumentException if no driver is found, or it is not a {@code driverClass}
     */
    @Override
    public <T extends FlowDriver> T getDriverAs(Chain chain, Class<T> driverClass) {
        FlowDriver driver = getDriver(chain);
        if (driverClass.isInstance(driver)) {
            //Checked cast: same result as (T) but without an unchecked-cast warning
            return driverClass.cast(driver);
        } else {
            throw new IllegalArgumentException("No " + driverClass.getSimpleName() + " found for: '" + chain.getDriver() + "'");
        }
    }

    private FlowStatefulService statefulService;

    /**
     * Returns the stateful service bound to this engine, creating it lazily on first access.
     *
     * <p>NOTE: the lazy initialization is not synchronized.
     */
    @Override
    public FlowStatefulService statefulService() {
        if (statefulService == null) {
            statefulService = new FlowStatefulServiceDefault(this);
        }

        return statefulService;
    }

    /**
     * Adds an interceptor with the given rank index, then re-sorts the interceptor list
     * (ordering is defined by {@link RankEntity}).
     */
    @Override
    public void addInterceptor(ChainInterceptor interceptor, int index) {
        interceptorList.add(new RankEntity<>(interceptor, index));
        Collections.sort(interceptorList);
    }

    /**
     * Removes the first registered entry whose target is the same instance (identity comparison).
     */
    @Override
    public void removeInterceptor(ChainInterceptor interceptor) {
        for (RankEntity<ChainInterceptor> i : interceptorList) {
            if (i.target == interceptor) {
                interceptorList.remove(i);
                break; //stop iterating right after the structural modification
            }
        }
    }

    /**
     * Registers a driver under a name; a {@code null} driver is ignored.
     */
    @Override
    public void register(String name, FlowDriver driver) {
        if (driver != null) {
            driverMap.put(name, driver);
        }
    }

    /**
     * Unregisters the driver with the given name; empty names are ignored, so the default
     * driver (registered under {@code ""}) cannot be removed here.
     */
    @Override
    public void unregister(String name) {
        if (Utils.isNotEmpty(name)) {
            driverMap.remove(name);
        }
    }

    /**
     * Loads (or replaces) a chain, keyed by its id.
     */
    @Override
    public void load(Chain chain) {
        chainMap.put(chain.getId(), chain);
    }

    /**
     * Unloads the chain with the given id.
     */
    @Override
    public void unload(String chainId) {
        chainMap.remove(chainId);
    }

    /**
     * Returns a live view of the loaded chains.
     */
    @Override
    public Collection<Chain> getChains() {
        return chainMap.values();
    }

    /**
     * Returns the loaded chain with the given id, or {@code null} if none.
     */
    @Override
    public Chain getChain(String chainId) {
        return chainMap.get(chainId);
    }

    /**
     * Evaluates a loaded chain.
     *
     * @param chainId   id of a loaded chain
     * @param startId   id of the node to start from; {@code null} means the chain's start node
     * @param depth     evaluation depth (0 stops before running any node; a negative value is
     *                  never decremented to 0, so it does not limit the walk)
     * @param exchanger the exchanger carrying the context and run state
     * @throws IllegalArgumentException if the chain, or the start node, is not found
     * @throws FlowException            if a condition or task handler fails
     */
    @Override
    public void eval(String chainId, String startId, int depth, FlowExchanger exchanger) throws FlowException {
        Chain chain = chainMap.get(chainId);
        if (chain == null) {
            throw new IllegalArgumentException("No chain found for id: " + chainId);
        }

        Node startNode;
        if (startId == null) {
            startNode = chain.getStart();
        } else {
            startNode = chain.getNode(startId);
        }

        eval(startNode, depth, exchanger);
    }

    /**
     * Evaluates starting from the given node.
     *
     * <p>While running, the exchanger is published in its context under {@link FlowExchanger#TAG};
     * any previous value is restored afterwards.
     *
     * @param startNode the node to start from
     * @param depth     evaluation depth
     * @param exchanger the exchanger carrying the context and run state
     * @throws IllegalArgumentException if {@code startNode} is null
     * @throws FlowException            if a condition or task handler fails
     */
    @Override
    public void eval(Node startNode, int depth, FlowExchanger exchanger) throws FlowException {
        if (startNode == null) {
            throw new IllegalArgumentException("The start node was not found.");
        }

        //Preparation
        prepare(exchanger);

        FlowDriver driver = getDriver(startNode.getChain());

        //Start execution
        //A different exchanger may already be set when called from another chain (cross-chain call)
        FlowExchanger bak = exchanger.context().getAs(FlowExchanger.TAG);
        try {
            if (bak != exchanger) {
                exchanger.context().put(FlowExchanger.TAG, exchanger);
            }

            new ChainInvocation(driver, exchanger, startNode, depth, this.interceptorList, this::evalDo).invoke();
        } finally {
            //Restore the previous exchanger (or remove ours if there was none)
            if (bak != exchanger) {
                if (bak == null) {
                    exchanger.context().remove(FlowExchanger.TAG);
                } else {
                    exchanger.context().put(FlowExchanger.TAG, bak);
                }
            }
        }
    }

    /**
     * Preparation: binds this engine to the exchanger if it has none yet.
     */
    protected void prepare(FlowExchanger exchanger) {
        if (exchanger.engine == null) {
            exchanger.engine = this;
        }
    }

    /**
     * Performs the evaluation (the terminal step of the interceptor chain).
     */
    protected void evalDo(ChainInvocation inv) throws FlowException {
        node_run(inv.getDriver(), inv.getExchanger(), inv.getStartNode(), inv.getEvalDepth());
    }

    /**
     * Called when a node starts running: notifies every interceptor, then the driver.
     */
    protected void onNodeStart(FlowDriver driver, FlowExchanger exchanger, Node node) {
        for (RankEntity<ChainInterceptor> interceptor : interceptorList) {
            interceptor.target.onNodeStart(exchanger.context(), node);
        }

        driver.onNodeStart(exchanger, node);
    }

    /**
     * Called when a node finishes running: notifies every interceptor, then the driver.
     */
    protected void onNodeEnd(FlowDriver driver, FlowExchanger exchanger, Node node) {
        for (RankEntity<ChainInterceptor> interceptor : interceptorList) {
            interceptor.target.onNodeEnd(exchanger.context(), node);
        }

        driver.onNodeEnd(exchanger, node);
    }

    /**
     * Tests a condition through the driver.
     *
     * @param def value returned when the condition has an empty description
     * @throws FlowException if the driver fails (non-FlowException errors are wrapped)
     */
    protected boolean condition_test(FlowDriver driver, FlowExchanger exchanger, Condition condition, boolean def) throws FlowException {
        if (Utils.isNotEmpty(condition.getDescription())) {
            try {
                return driver.handleCondition(exchanger, condition);
            } catch (FlowException e) {
                throw e;
            } catch (Throwable e) {
                throw new FlowException("The test handle failed: " + condition.getChain().getId() + " / " + condition.getDescription(), e);
            }
        } else {
            return def;
        }
    }

    /**
     * Executes the node's task if its "when" condition passes (an empty condition counts as true).
     *
     * @throws FlowException if the driver fails (non-FlowException errors are wrapped)
     */
    protected void task_exec(FlowDriver driver, FlowExchanger exchanger, Node node) throws FlowException {
        //Test the condition; defaults to true
        if (condition_test(driver, exchanger, node.getWhen(), true)) {
            //Acts as an event trigger; the handler is expected to "filter" empty tasks
            try {
                driver.handleTask(exchanger, node.getTask());
            } catch (FlowException e) {
                throw e;
            } catch (Throwable e) {
                throw new FlowException("The task handle failed: " + node.getChain().getId() + " / " + node.getId(), e);
            }
        }
    }

    /**
     * Runs a node and, depending on its type, the nodes after it.
     *
     * @return {@code false} if the node is null, the run is stopped or interrupted, or the node is
     * still waiting for other incoming branches; otherwise the node's completion flag
     * (when {@code true}, {@link #onNodeEnd} has been fired)
     */
    protected boolean node_run(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        if (node == null) {
            return false;
        }

        //Stopped
        if (exchanger.isStopped()) {
            return false;
        }

        //Interrupted: this branch does not flow further
        if (exchanger.isInterrupted()) {
            //Reset the interrupt (does not affect other branches)
            exchanger.interrupt(false);
            return false;
        }

        //Depth control (a negative depth never reaches 0)
        if (depth == 0) {
            return true;
        } else {
            depth--;
        }

        //Before-run event
        onNodeStart(driver, exchanger, node);

        //Stopped
        if (exchanger.isStopped()) {
            return false;
        }

        //Interrupted: do not continue (onNodeStart may have triggered the interrupt)
        if (exchanger.isInterrupted()) {
            //Reset the interrupt (does not affect other branches)
            exchanger.interrupt(false);
            return false;
        }


        boolean node_end = true;

        switch (node.getType()) {
            case START:
                //Go to the next node
                node_run(driver, exchanger, node.getNextNode(), depth);
                break;
            case END:
                break;
            case ACTIVITY:
                node_end = activity_run(driver, exchanger, node, depth);
                break;
            case INCLUSIVE: //inclusive gateway (multi-choice)
                node_end = inclusive_run(driver, exchanger, node, depth);
                break;
            case EXCLUSIVE: //exclusive gateway (single-choice)
                exclusive_run(driver, exchanger, node, depth);
                break;
            case PARALLEL: //parallel gateway (all branches)
                node_end = parallel_run(driver, exchanger, node, depth);
                break;
            case ITERATOR:
                node_end = iterator_run(driver, exchanger, node, depth);
                break;
        }

        //After-run event
        if (node_end) {
            onNodeEnd(driver, exchanger, node);
        }


        return node_end;
    }

    /**
     * Runs an activity node: applies its in-mode (join), executes its task, then routes
     * according to its out-mode (PARALLEL, INCLUSIVE, or exclusive by default).
     */
    protected boolean activity_run(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) {
        //In-mode (join behavior)
        if (node.getImode() == NodeType.PARALLEL) {
            if (parallel_run_in(driver, exchanger, node, depth) == false) {
                return false;
            }
        } else if (node.getImode() == NodeType.INCLUSIVE) {
            if (inclusive_run_in(driver, exchanger, node, depth) == false) {
                return false;
            }
        }

        //Try to execute the task (may be empty)
        task_exec(driver, exchanger, node);

        //Stopped
        if (exchanger.isStopped()) {
            return false;
        }

        //Interrupted: do not continue (the task may have triggered the interrupt)
        if (exchanger.isInterrupted()) {
            //Reset the interrupt (does not affect other branches)
            exchanger.interrupt(false);
            return false;
        }

        //Out-mode (split behavior)
        if (node.getOmode() == NodeType.PARALLEL) {
            //Parallel gateway mode
            return parallel_run_out(driver, exchanger, node, depth);
        } else if (node.getOmode() == NodeType.INCLUSIVE) {
            //Inclusive gateway mode
            return inclusive_run_out(driver, exchanger, node, depth);
        } else {
            //Default: exclusive gateway mode
            return exclusive_run(driver, exchanger, node, depth);
        }
    }

    /**
     * Runs an inclusive gateway: join, then (if all expected branches arrived) split.
     */
    protected boolean inclusive_run(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        if (inclusive_run_in(driver, exchanger, node, depth)) {
            return inclusive_run_out(driver, exchanger, node, depth);
        } else {
            return false;
        }
    }

    /**
     * Inclusive join: with multiple incoming links, waits until the number of arrivals reaches the
     * branch count pushed by the matching inclusive split, then pops that count.
     *
     * @return {@code false} while still waiting for other branches
     */
    protected boolean inclusive_run_in(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        Stack<Integer> inclusive_stack = exchanger.temporary().stack(node.getChain(), "inclusive_run");

        //::in
        if (node.getPrevLinks().size() > 1) { //multiple incoming links (may need to wait)
            if (inclusive_stack.size() > 0) {
                int start_size = inclusive_stack.peek();
                int in_size = exchanger.temporary().countIncr(node.getChain(), node.getId());//accumulated arrivals
                if (start_size > in_size) { //wait for all branches to arrive
                    return false;
                }

                //Join complete: pop this stack entry
                inclusive_stack.pop();
            }
            //An empty stack means no inclusive split has been recorded yet
        }

        return true;
    }

    /**
     * Inclusive split: runs every outgoing link whose condition matches (pushing the number of
     * matched branches for the join), or the default (condition-less) link if none matched.
     */
    protected boolean inclusive_run_out(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        Stack<Integer> inclusive_stack = exchanger.temporary().stack(node.getChain(), "inclusive_run");

        //::out
        Link def_line = null;
        List<Link> matched_lines = new ArrayList<>();

        for (Link l : node.getNextLinks()) {
            if (l.getWhen().isEmpty()) {
                def_line = l;
            } else {
                if (condition_test(driver, exchanger, l.getWhen(), false)) {
                    matched_lines.add(l);
                }
            }
        }

        if (matched_lines.size() > 0) {
            //Record the number of outgoing branches
            inclusive_stack.push(matched_lines.size());

            //Run every matched branch
            for (Link l : matched_lines) {
                node_run(driver, exchanger, l.getNextNode(), depth);
            }
        } else if (def_line != null) {
            //Only the default branch: no need to record a branch count
            node_run(driver, exchanger, def_line.getNextNode(), depth);
        }

        return true;
    }

    /**
     * Runs an exclusive gateway: follows the first outgoing link whose condition matches,
     * otherwise the default (condition-less) link, if any.
     */
    protected boolean exclusive_run(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        //::out
        Link def_line = null; //default link
        for (Link l : node.getNextLinks()) {
            if (l.getWhen().isEmpty()) {
                def_line = l;
            } else {
                if (condition_test(driver, exchanger, l.getWhen(), false)) {
                    //Run the first matching branch
                    node_run(driver, exchanger, l.getNextNode(), depth);
                    return true; //done
                }
            }
        }

        if (def_line != null) {
            //Fall back to the default branch
            node_run(driver, exchanger, def_line.getNextNode(), depth);
        }

        return true;
    }

    /**
     * Runs a parallel gateway: join, then (once all incoming links arrived) split.
     */
    protected boolean parallel_run(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        if (parallel_run_in(driver, exchanger, node, depth)) {
            return parallel_run_out(driver, exchanger, node, depth);
        } else {
            return false;
        }
    }

    /**
     * Parallel join: counts arrivals and waits until every incoming link has arrived.
     *
     * @return {@code false} while still waiting for other branches
     */
    protected boolean parallel_run_in(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        //::in
        int count = exchanger.temporary().countIncr(node.getChain(), node.getId());//accumulated arrivals
        if (node.getPrevLinks().size() > count) { //wait for all branches to arrive
            return false;
        }

        return true;
    }

    /**
     * Parallel split: resets the join counter and runs every next node. Runs them sequentially
     * when the context has no executor or there are fewer than two next nodes; otherwise submits
     * each to the executor and waits for all of them.
     *
     * @throws FlowException if a branch fails (the first failure is rethrown, non-FlowException
     *                       errors wrapped), or if the waiting thread is interrupted
     */
    protected boolean parallel_run_out(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) throws FlowException {
        //Reset the join counter
        exchanger.temporary().countSet(node.getChain(), node.getId(), 0);

        //::out
        if (exchanger.context().executor() == null || node.getNextNodes().size() < 2) { //fewer than 2: no need for the pool
            //Single thread
            for (Node n : node.getNextNodes()) {
                node_run(driver, exchanger, n, depth);
            }
        } else {
            //Multiple threads
            CountDownLatch cdl = new CountDownLatch(node.getNextNodes().size());
            AtomicReference<Throwable> errorRef = new AtomicReference<>();
            for (Node n : node.getNextNodes()) {
                exchanger.context().executor().execute(() -> {
                    try {
                        //Skip remaining branches once one has failed
                        if (errorRef.get() != null) {
                            return;
                        }

                        node_run(driver, exchanger, n, depth);
                    } catch (Throwable ex) {
                        //Keep the first failure (most likely the root cause)
                        errorRef.compareAndSet(null, ex);
                    } finally {
                        cdl.countDown();
                    }
                });
            }

            //Wait for all branches
            try {
                cdl.await();
            } catch (InterruptedException e) {
                //Restore the interrupt flag and fail instead of continuing the flow while
                //branch tasks may still be running against the shared exchanger
                Thread.currentThread().interrupt();
                throw new FlowException("The parallel run was interrupted: " + node.getChain().getId() + " / " + node.getId(), e);
            }

            //Error propagation
            Throwable error = errorRef.get();
            if (error != null) {
                if (error instanceof FlowException) {
                    throw (FlowException) error;
                } else {
                    throw new FlowException(error);
                }
            }
        }

        return true;
    }

    /**
     * Runs an iterator node. A node with a {@code $for} meta opens the loop; a node without it
     * acts as the closing node and only continues once the current iteration is exhausted.
     */
    protected boolean iterator_run(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) {
        if (node.getMeta("$for") == null) {
            //Loop end
            if (iterator_run_in(driver, exchanger, node, depth)) {
                return node_run(driver, exchanger, node.getNextNode(), depth);
            } else {
                return false;
            }
        } else {
            //Loop start
            return iterator_run_out(driver, exchanger, node, depth);
        }
    }

    /**
     * Iterator join: waits while the innermost open iterator still has elements; once exhausted,
     * pops it.
     *
     * @return {@code false} while the iteration is still in progress
     */
    protected boolean iterator_run_in(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) {
        Stack<Iterator> iterator_stack = exchanger.temporary().stack(node.getChain(), "iterator_run");

        //::in
        if (iterator_stack.size() > 0) {
            Iterator inIterator = iterator_stack.peek();
            if (inIterator.hasNext()) { //wait for the iteration to finish
                return false;
            }

            //Join complete: pop this stack entry
            iterator_stack.pop();
        }
        //An empty stack means no loop has been opened yet

        return true;
    }

    /**
     * Iterator split: reads the collection named by the {@code $in} meta from the context
     * (an {@link Iterator} or {@link Iterable}), pushes its iterator, and for each element puts it
     * into the context under the {@code $for} key and runs the next node.
     *
     * @throws FlowException if the {@code $in} value is neither an Iterator nor an Iterable
     */
    protected boolean iterator_run_out(FlowDriver driver, FlowExchanger exchanger, Node node, int depth) {
        String forKey = node.getMeta("$for");
        String inKey = node.getMeta("$in");
        Object inObj = exchanger.context().getAs(inKey);

        Iterator inIterator = null;
        if (inObj instanceof Iterator) {
            inIterator = (Iterator) inObj;
        } else if (inObj instanceof Iterable) {
            inIterator = ((Iterable) inObj).iterator();
        } else {
            throw new FlowException(inKey + " is not a Iterable");
        }


        Stack<Iterator> iterator_stack = exchanger.temporary().stack(node.getChain(), "iterator_run");
        iterator_stack.push(inIterator);

        //::out
        while (inIterator.hasNext()) {
            Object item = inIterator.next();
            exchanger.context().put(forKey, item);
            node_run(driver, exchanger, node.getNextNode(), depth);
        }

        return true;
    }
}