/*
 * Decompiled with CFR 0.152.
 */
package network.aika.neuron.activation.search;

import java.io.Serializable;
import java.util.Map;
import java.util.TreeMap;
import network.aika.Document;
import network.aika.Utils;
import network.aika.neuron.Synapse;
import network.aika.neuron.activation.Activation;
import network.aika.neuron.activation.search.Branch;
import network.aika.neuron.activation.search.Decision;
import network.aika.neuron.activation.search.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SearchNode
implements Comparable<SearchNode> {
    private static final Logger log = LoggerFactory.getLogger(SearchNode.class);
    public static int MAX_SEARCH_STEPS = Integer.MAX_VALUE;
    public static boolean OPTIMIZE_SEARCH = true;
    public static boolean COMPUTE_SOFT_MAX = false;
    private int id;
    private SearchNode parent;
    private Decision decision;
    private Activation act;
    private int level;
    public Branch selected = new Branch();
    public Branch excluded = new Branch();
    private double weightDelta;
    private double accumulatedWeight = 0.0;
    private Map<Activation, Option> modifiedActs = new TreeMap<Activation, Option>(Activation.ACTIVATION_ID_COMP);
    private Step step = Step.INIT;
    private Decision currentChildDecision = Decision.UNKNOWN;
    private long processVisited;
    private boolean bestPath;
    private int cachedCount = 1;
    private int cachedFactor = 1;
    private DebugState debugState;

    public SearchNode(Document doc, Decision d, SearchNode p, int level) {
        this.id = doc.searchNodeIdCounter++;
        this.decision = d;
        this.parent = p;
        this.level = level;
    }

    public Branch getBranch(Decision d) {
        switch (d) {
            case SELECTED: {
                return this.selected;
            }
            case EXCLUDED: {
                return this.excluded;
            }
        }
        return null;
    }

    public SearchNode getAlternative() {
        return this.parent.getBranch((Decision)this.decision.getInverted()).child;
    }

    public void updateActivations(Document doc) throws Activation.OscillatingActivationsException {
        Activation act = this.getActivation();
        this.weightDelta = doc.getValueQueue().process(this);
        if (act != null && this.followPath()) {
            act.cachedSearchNode = this;
        }
        if (this.parent != null) {
            this.accumulatedWeight = this.weightDelta + this.parent.accumulatedWeight;
        }
    }

    public boolean followPath() {
        return this.getActivation().currentOption.searchNode == this && this.decision == this.getActivation().currentOption.getState().getPreferredDecision();
    }

    public int getId() {
        return this.id;
    }

    public Map<Activation, Option> getModifiedActivations() {
        return this.modifiedActs;
    }

    public double getAccumulatedWeight() {
        return this.accumulatedWeight;
    }

    public Activation getActivation() {
        return this.parent != null ? this.parent.act : null;
    }

    public static void search(Document doc, SearchNode root, long v, Long timeoutInMilliSeconds) throws TimeoutException, Activation.RecursiveDepthExceededException, Activation.OscillatingActivationsException {
        SearchNode sn = root;
        double returnWeight = 0.0;
        double returnWeightSum = 0.0;
        long startTime = System.currentTimeMillis();
        do {
            if (sn.processVisited != v) {
                sn.step = Step.INIT;
                sn.processVisited = v;
            }
            switch (sn.step) {
                case INIT: {
                    if (sn.level >= doc.candidates.size()) {
                        SearchNode.checkTimeoutCondition(timeoutInMilliSeconds, startTime);
                        returnWeightSum = returnWeight = sn.processResult(doc);
                        sn.step = Step.FINAL;
                        sn = sn.parent;
                        break;
                    }
                    sn.initStep(doc);
                    sn.step = Step.SELECT;
                    break;
                }
                case SELECT: {
                    if (sn.prepareStep(doc, Decision.SELECTED)) {
                        sn.step = Step.POST_SELECT;
                        sn = sn.selected.child;
                        break;
                    }
                    sn.step = Step.EXCLUDE;
                    break;
                }
                case POST_SELECT: {
                    sn.selected.postStep(returnWeight, returnWeightSum);
                    sn.step = Step.SELECT;
                    break;
                }
                case EXCLUDE: {
                    if (sn.prepareStep(doc, Decision.EXCLUDED)) {
                        sn.step = Step.POST_EXCLUDE;
                        sn = sn.excluded.child;
                        break;
                    }
                    sn.step = Step.FINAL;
                    break;
                }
                case POST_EXCLUDE: {
                    sn.excluded.postStep(returnWeight, returnWeightSum);
                    sn.step = Step.SELECT;
                    break;
                }
                case FINAL: {
                    returnWeight = sn.finalStep();
                    returnWeightSum = sn.getWeightSum();
                    sn = sn.parent;
                    break;
                }
            }
        } while (sn != null);
    }

    public void setWeight(double w) {
        for (Option sc : this.modifiedActs.values()) {
            sc.setWeight(w);
        }
    }

    private static void checkTimeoutCondition(Long timeoutInMilliSeconds, long startTime) throws TimeoutException {
        if (timeoutInMilliSeconds != null && System.currentTimeMillis() > startTime + timeoutInMilliSeconds) {
            throw new TimeoutException("Interpretation search took too long: " + (System.currentTimeMillis() - startTime) + "ms");
        }
    }

    public double getWeightSum() {
        return this.selected.weightSum + this.excluded.weightSum;
    }

    private void initStep(Document doc) throws Activation.RecursiveDepthExceededException {
        Decision cd;
        this.act = doc.candidates.get(this.level);
        if (OPTIMIZE_SEARCH && (cd = this.getCachedDecision()) != null && cd != Decision.UNKNOWN) {
            SearchNode asn;
            this.getBranch((Decision)cd).weightSum = this.act.alternativeCachedWeightSum;
            if (COMPUTE_SOFT_MAX && (asn = this.act.cachedSearchNode.getAlternative()) != null) {
                ++asn.cachedCount;
            }
        }
        if (doc.searchStepCounter > MAX_SEARCH_STEPS) {
            this.dumpDebugState();
            throw new RuntimeException("Max search step exceeded.");
        }
        ++doc.searchStepCounter;
        this.storeDebugInfos();
    }

    private Decision getCachedDecision() {
        return this.act.cachedDecision;
    }

    private boolean prepareStep(Document doc, Decision d) throws Activation.OscillatingActivationsException {
        Branch b = this.getBranch(d);
        if (b.visited) {
            return false;
        }
        b.visited = true;
        if (OPTIMIZE_SEARCH && this.getCachedDecision() == d.getInverted() && (this.selected.searched || d == Decision.SELECTED)) {
            return false;
        }
        SearchNode child = new SearchNode(doc, d, this, this.level + 1);
        if (b.prepareStep(doc, child)) {
            return false;
        }
        if (d == Decision.SELECTED && this.act.cachedDecision == Decision.UNKNOWN) {
            this.invalidateCachedDecisions();
        }
        int n = d.ordinal();
        this.act.debugDecisionCounts[n] = this.act.debugDecisionCounts[n] + 1;
        return true;
    }

    private double finalStep() {
        Decision d;
        Decision decision = d = this.selected.weight >= this.excluded.weight ? Decision.SELECTED : Decision.EXCLUDED;
        if (this.selected.searched && this.excluded.searched) {
            this.act.cachedDecision = d;
            this.act.alternativeCachedWeightSum = this.getBranch((Decision)this.act.cachedDecision).weightSum;
        }
        Branch b = this.getBranch(d);
        SearchNode cn = b.child;
        if (cn != null && cn.bestPath) {
            this.act.bestChildNode = cn;
            this.bestPath = true;
        }
        if (!COMPUTE_SOFT_MAX) {
            if (!this.bestPath) {
                b.child = null;
            }
            this.getBranch((Decision)d.getInverted()).child = null;
        }
        return b.weight;
    }

    private void invalidateCachedDecisions() {
        this.act.getOutputLinks().filter(l -> !l.isNegative(Synapse.State.CURRENT)).forEach(l -> SearchNode.invalidateCachedDecision(l.getOutput()));
    }

    public static void invalidateCachedDecision(Activation act) {
        if (act != null && act.cachedDecision == Decision.EXCLUDED) {
            act.cachedDecision = Decision.UNKNOWN;
            SearchNode pn = act.cachedSearchNode.parent;
            if (pn != null) {
                pn.selected.repeat();
            }
        }
        act.getInputLinks().filter(l -> l.isRecurrent() && l.isNegative(Synapse.State.CURRENT)).map(l -> l.getInput()).filter(c -> c.cachedDecision == Decision.SELECTED).forEach(c -> {
            c.cachedDecision = Decision.UNKNOWN;
        });
    }

    private double processResult(Document doc) {
        double accNW = this.accumulatedWeight;
        if (this.level > doc.selectedSearchNode.level || accNW > this.getSelectedAccumulatedWeight(doc)) {
            doc.selectedSearchNode = this;
            doc.storeFinalState();
            this.bestPath = true;
        } else {
            this.bestPath = false;
        }
        return this.accumulatedWeight;
    }

    public static void computeCachedFactor(SearchNode sn) {
        while (sn != null) {
            switch (sn.currentChildDecision) {
                case UNKNOWN: {
                    sn.currentChildDecision = Decision.SELECTED;
                    if (sn.selected.child == null) break;
                    sn = sn.selected.child;
                    sn.computeCacheFactor();
                    break;
                }
                case SELECTED: {
                    sn.currentChildDecision = Decision.EXCLUDED;
                    if (sn.excluded.child == null) break;
                    sn = sn.excluded.child;
                    sn.computeCacheFactor();
                    break;
                }
                case EXCLUDED: {
                    sn = sn.parent;
                }
            }
        }
    }

    private void computeCacheFactor() {
        this.cachedFactor = (this.parent != null ? this.parent.cachedFactor : 1) * this.cachedCount;
        for (Option sc : this.modifiedActs.values()) {
            sc.setCacheFactor(this.cachedFactor);
        }
    }

    private double getSelectedAccumulatedWeight(Document doc) {
        return doc.selectedSearchNode != null ? doc.selectedSearchNode.accumulatedWeight : -1.0;
    }

    public void changeState(Activation.Mode m) {
        this.modifiedActs.values().forEach(sc -> sc.restoreState(m));
    }

    @Override
    public int compareTo(SearchNode sn) {
        return Integer.compare(this.id, sn.id);
    }

    public Decision getDecision() {
        return this.decision;
    }

    private void storeDebugInfos() {
        this.debugState = this.getDebugState();
        int n = this.debugState.ordinal();
        this.act.debugCounts[n] = this.act.debugCounts[n] + 1;
    }

    private DebugState getDebugState() {
        if (!this.selected.searched || !this.excluded.searched) {
            return DebugState.LIMITED;
        }
        if (this.getCachedDecision() != Decision.UNKNOWN) {
            return DebugState.CACHED;
        }
        return DebugState.EXPLORE;
    }

    public void dumpDebugState() {
        SearchNode n = this;
        Object weights = "";
        Decision decision = Decision.UNKNOWN;
        while (n != null && n.level >= 0) {
            log.info(n.level + " " + n.debugState + " DECISION:" + decision + (String)weights + " " + (n.act != null ? n.act.toString() : "") + " MOD-ACTS:" + n.modifiedActs.size());
            decision = n.decision;
            weights = " AW:" + Utils.round(n.accumulatedWeight) + " DW:" + Utils.round(n.weightDelta);
            n = n.parent;
        }
    }

    public String toString() {
        return "id:" + this.id + " actId:" + (Serializable)(this.act != null ? Integer.valueOf(this.act.getId()) : "-") + " Decision:" + this.getDecision() + " curDec:" + this.currentChildDecision;
    }

    public static class TimeoutException
    extends RuntimeException {
        public TimeoutException(String message) {
            super(message);
        }
    }

    public static enum DebugState {
        CACHED,
        LIMITED,
        EXPLORE;

    }

    private static enum Step {
        INIT,
        SELECT,
        POST_SELECT,
        EXCLUDE,
        POST_EXCLUDE,
        FINAL;

    }
}

