/*
 * Decompiled with CFR 0.152.
 */
package org.xxdc.oss.example.bot;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.xxdc.oss.example.GameState;
import org.xxdc.oss.example.bot.BotStrategy;
import org.xxdc.oss.example.bot.BotStrategyConfig;

/**
 * A {@link BotStrategy} that chooses a move with Monte Carlo Tree Search (MCTS).
 *
 * <p>Each iteration runs four steps:
 * <ol>
 *   <li>Descend the tree with UCT selection.</li>
 *   <li>Expand one untried move.</li>
 *   <li>Play random moves until the game reaches a terminal state.</li>
 *   <li>Add the per-player playout rewards to every node on the path back to the root.</li>
 * </ol>
 * The search stops when the time or iteration limit of the supplied {@link BotStrategyConfig} is
 * reached. It then returns the move that leads to the most-visited child of the root.
 */
public final class MonteCarloTreeSearch
implements BotStrategy {
    private static final System.Logger log = System.getLogger(MonteCarloTreeSearch.class.getName());

    /** Playout reward given to every non-winning player when some player wins. */
    private static final double MIN_SCORE = -0.5;
    /** Playout reward given to the winning player. */
    private static final double MAX_SCORE = 1.0;
    /** Playout reward given to every player when nobody wins. */
    private static final double DRAW_SCORE = 0.0;
    /** Weight inside the UCT square root; 2.0 gives the classic UCB1 term sqrt(2 ln N / n). */
    private static final double UCT_EXPLORATION = 2.0;

    private final GameState initialState;
    private final BotStrategyConfig config;
    /** Reused source of randomness for expansion and playouts, instead of one Random per move. */
    private final Random random = new Random();

    /**
     * Creates a search over {@code state} with a default budget of one second.
     *
     * @param state the position to choose a move for
     */
    public MonteCarloTreeSearch(GameState state) {
        this(state, BotStrategyConfig.newBuilder().maxTimeMillis(TimeUnit.SECONDS, 1L).build());
    }

    /**
     * Creates a search over {@code state}, bounded by the limits in {@code config}.
     *
     * @param state  the position to choose a move for
     * @param config supplies the time and iteration limits checked on every iteration
     */
    public MonteCarloTreeSearch(GameState state, BotStrategyConfig config) {
        this.initialState = state;
        this.config = config;
    }

    /**
     * Runs a fresh search from the initial state and returns the selected move.
     *
     * @return the move ({@code lastMove()} of the chosen child state) leading to the most-visited
     *     child of the search root
     * @throws NoSuchElementException if the root was never expanded, because the initial state
     *     is terminal or the configured limits allowed zero iterations
     */
    @Override
    public int bestMove() {
        return this.monteCarloTreeSearch(this.initialState);
    }

    /** Builds a search tree rooted at {@code state} within the budget, then picks a move. */
    private int monteCarloTreeSearch(GameState state) {
        MCTSNode root = new MCTSNode(state);
        // nanoTime is monotonic; currentTimeMillis follows the wall clock and can jump.
        long startNanos = System.nanoTime();
        for (int iterations = 0;
                !this.config.exceedsMaxTimeMillis(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos))
                        && !this.config.exceedsMaxIterations(iterations);
                iterations++) {
            MCTSNode leaf = this.treePolicy(root);
            double[] reward = this.defaultPolicy(leaf.state);
            this.backpropagate(leaf, reward);
        }
        boolean debug = log.isLoggable(System.Logger.Level.DEBUG);
        if (debug) {
            log.log(System.Logger.Level.DEBUG, "MCTS: \n" + root);
        }
        MCTSNode best = this.bestChild(root);
        if (debug) {
            log.log(System.Logger.Level.DEBUG, "MCTS (Selected): \n" + best.state.lastMove());
        }
        return best.state.lastMove();
    }

    /**
     * Descends from {@code node} until it reaches a terminal state or a node with untried moves.
     *
     * @return a newly expanded child, or the terminal node the descent ended on
     */
    private MCTSNode treePolicy(MCTSNode node) {
        while (!node.state.isTerminal()) {
            if (!node.isFullyExpanded()) {
                return this.expand(node);
            }
            node = node.select();
        }
        return node;
    }

    /**
     * Adds one child to {@code node}, reached by a move chosen uniformly at random from the moves
     * that have no child yet.
     *
     * @return the new child (already linked into {@code node.children})
     */
    private MCTSNode expand(MCTSNode node) {
        List<Integer> untriedMoves = new ArrayList<>(node.state.board().availableMoves());
        // Collecting to a set keeps removeAll linear in the number of available moves.
        untriedMoves.removeAll(node.children.stream()
                .map(existing -> existing.state.lastMove())
                .collect(Collectors.toSet()));
        int move = untriedMoves.get(this.random.nextInt(untriedMoves.size()));
        MCTSNode child = new MCTSNode(node.state.afterPlayerMoves(move), node);
        node.children.add(child);
        return child;
    }

    /**
     * Plays uniformly random moves from a copy of {@code state} until the game is terminal.
     *
     * @return per-player rewards for the final state, indexed like {@code playerMarkers()}
     */
    private double[] defaultPolicy(GameState state) {
        GameState tempState = new GameState(state);
        while (!tempState.isTerminal()) {
            List<Integer> moves = tempState.board().availableMoves();
            int move = moves.get(this.random.nextInt(moves.size()));
            tempState = tempState.afterPlayerMoves(move);
        }
        return this.defaultReward(tempState);
    }

    /**
     * Scores a finished game.
     *
     * <p>The first player, in {@code playerMarkers()} order, whose marker forms a chain gets
     * {@link #MAX_SCORE} and every other player gets {@link #MIN_SCORE}. If no player has a
     * chain, every player gets {@link #DRAW_SCORE}.
     */
    private double[] defaultReward(GameState state) {
        int playerCount = state.playerMarkers().size();
        int winner = -1;
        for (int i = 0; i < playerCount; i++) {
            if (state.board().hasChain(state.playerMarkers().get(i))) {
                winner = i;
                break;
            }
        }
        double[] reward = new double[playerCount];
        for (int i = 0; i < playerCount; i++) {
            if (winner == -1) {
                reward[i] = DRAW_SCORE;
            } else {
                reward[i] = i == winner ? MAX_SCORE : MIN_SCORE;
            }
        }
        return reward;
    }

    /** Walks from {@code node} up to the root, incrementing visits and adding {@code reward}. */
    private void backpropagate(MCTSNode node, double[] reward) {
        int playerCount = this.initialState.playerMarkers().size();
        for (MCTSNode current = node; current != null; current = current.parent) {
            current.visits++;
            for (int i = 0; i < playerCount; i++) {
                current.scores[i] += reward[i];
            }
        }
    }

    /**
     * Returns the most-visited child of {@code node}. On ties, the earliest child wins.
     *
     * @throws NoSuchElementException if {@code node} has no children
     */
    private MCTSNode bestChild(MCTSNode node) {
        return node.children.stream()
                .max(Comparator.comparingInt(child -> child.visits))
                .orElseThrow(() -> new NoSuchElementException(
                        "MCTS root has no children: the state is terminal or no iterations were run"));
    }

    /** A search-tree node holding a game state plus its visit count and per-player score totals. */
    static class MCTSNode {
        GameState state;
        MCTSNode parent;
        List<MCTSNode> children;
        int visits;
        /** Accumulated playout rewards, indexed like {@code state.playerMarkers()}. */
        double[] scores;

        /** Creates a root node (no parent) for {@code state}. */
        public MCTSNode(GameState state) {
            this(state, null);
        }

        /** Creates a node for {@code state} under {@code parent}, with zero visits and scores. */
        public MCTSNode(GameState state, MCTSNode parent) {
            this.state = state;
            this.parent = parent;
            this.children = new ArrayList<MCTSNode>();
            this.visits = 0;
            this.scores = new double[state.playerMarkers().size()];
        }

        /**
         * Picks the child with the highest UCT value, judged from the viewpoint of the player to
         * move at this node.
         *
         * <p>UCT = average score + sqrt(2 ln N / n). The search backpropagates every child right
         * after creating it, so each child has {@code visits >= 1} here.
         *
         * @return the best child, or {@code null} if this node has no children
         */
        public MCTSNode select() {
            int player = this.state.currentPlayerIndex();
            double logVisits = Math.log(this.visits);
            MCTSNode selected = null;
            double bestValue = Double.NEGATIVE_INFINITY;
            for (MCTSNode child : this.children) {
                double exploitation = child.scores[player] / (double) child.visits;
                double exploration = Math.sqrt(UCT_EXPLORATION * logVisits / (double) child.visits);
                double uctValue = exploitation + exploration;
                if (uctValue > bestValue) {
                    selected = child;
                    bestValue = uctValue;
                }
            }
            return selected;
        }

        /** Returns whether every currently available move already has a child node. */
        public boolean isFullyExpanded() {
            return this.children.size() == this.state.board().availableMoves().size();
        }

        /** Renders this node and its whole subtree, indented by depth. */
        @Override
        public String toString() {
            return this.toString(0);
        }

        /**
         * Renders this node and its subtree. Each node prints two lines:
         * <ul>
         *   <li>the move, visit count, and either "WINNER" or the remaining available moves;</li>
         *   <li>the per-player score totals.</li>
         * </ul>
         */
        String toString(int depth) {
            String indent = " ".repeat(depth * 2);
            StringBuilder builder = new StringBuilder();
            builder.append(indent);
            builder.append(this.parent == null
                    ? "Root"
                    : this.state.playerMarkers().get(this.state.lastPlayerIndex()) + " -> " + this.state.lastMove());
            builder.append(" (").append(this.visits).append(") => ");
            builder.append(this.parent != null && this.state.lastPlayerHasChain() ? "WINNER" : this.state.availableMoves());
            builder.append("\n").append(indent).append(" (");
            for (int i = 0; i < this.scores.length; ++i) {
                if (i > 0) {
                    builder.append(", ");
                }
                builder.append(this.state.playerMarkers().get(i)).append(": ").append(this.scores[i]);
            }
            builder.append(")");
            for (MCTSNode child : this.children) {
                builder.append("\n").append(child.toString(depth + 1));
            }
            return builder.toString();
        }
    }
}

