/*
 * Decompiled with CFR 0.152.
 */
package org.aika.neuron.activation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;
import org.aika.Document;
import org.aika.Utils;
import org.aika.lattice.Node;
import org.aika.lattice.NodeActivation;
import org.aika.lattice.OrNode;
import org.aika.neuron.INeuron;
import org.aika.neuron.Neuron;
import org.aika.neuron.Synapse;
import org.aika.neuron.activation.Candidate;
import org.aika.neuron.activation.Conflicts;
import org.aika.neuron.activation.Linker;
import org.aika.neuron.activation.Range;
import org.aika.neuron.activation.SearchNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A single activation of a neuron within a {@link Document}.
 *
 * <p>The activation is identified by its {@link NodeActivation.Key}: the {@link OrNode} of the
 * neuron, a text range and an optional relational id. It holds:
 * <ul>
 *   <li>the synapse links to its input and output activations;</li>
 *   <li>the per-round states computed during the search ({@link #rounds});</li>
 *   <li>the upper/lower bounds computed by {@link #computeBounds()};</li>
 *   <li>the search decision currently assigned to it.</li>
 * </ul>
 */
public final class Activation
extends NodeActivation<OrNode> {
    /** Orders activations by ascending {@code id}. */
    public static final Comparator<Activation> ACTIVATION_ID_COMP = Comparator.comparingInt(act -> act.id);
    /** Recursion limit for {@link #checkSelfReferencing}; deeper paths are treated as not self-referencing. */
    public static int MAX_SELF_REFERENCING_DEPTH = 5;
    private static final Logger log = LoggerFactory.getLogger(Activation.class);

    /** Input links whose input activation is SELECTED, ordered by {@link SynapseActivation#INPUT_COMP}. */
    public TreeSet<SynapseActivation> selectedNeuronInputs = new TreeSet<>(SynapseActivation.INPUT_COMP);
    /** All input links of this activation, ordered by {@link SynapseActivation#INPUT_COMP}. */
    public TreeSet<SynapseActivation> neuronInputs = new TreeSet<>(SynapseActivation.INPUT_COMP);
    /** All output links of this activation, ordered by {@link SynapseActivation#OUTPUT_COMP}. */
    public TreeSet<SynapseActivation> neuronOutputs = new TreeSet<>(SynapseActivation.OUTPUT_COMP);
    /** Cached result of {@link #getSequence()}; {@code null} until first computed. */
    public Integer sequence;
    /** Activation-function value of the upper-bound net input (see {@link #computeBounds()}). */
    public double upperBound;
    /** Activation-function value of the lower-bound net input (see {@link #computeBounds()}). */
    public double lowerBound;
    /** Per-round states for the current search path. */
    public Rounds rounds = new Rounds();
    /**
     * States read by {@link #getFinalState()}. Initially this is the SAME object as {@link #rounds}.
     * NOTE(review): presumably replaced by a snapshot once a final interpretation is chosen;
     * confirm against callers.
     */
    public Rounds finalRounds = this.rounds;
    public boolean ubQueued = false;
    public boolean isQueued = false;
    public long queueId;
    public long markedHasCandidate;
    /** Visit id for which {@link #currentStateChange} was recorded. */
    public long currentStateV;
    /** State change recorded for visit {@link #currentStateV} (see {@link #saveOldState}). */
    public StateChange currentStateChange;
    public long markedDirty;
    public double errorSignal;
    /** Training target; kept in sync with {@code doc.supervisedTraining.targetActivations}. */
    public Double targetValue;
    /** When non-null, {@link #process} uses this fixed value instead of computing one. */
    public Double inputValue;
    /** Externally fixed decision; when not UNKNOWN, {@link #setDecision} only accepts this value. */
    public SearchNode.Decision inputDecision = SearchNode.Decision.UNKNOWN;
    public SearchNode.Decision decision = SearchNode.Decision.UNKNOWN;
    public SearchNode.Decision finalDecision = SearchNode.Decision.UNKNOWN;
    public Candidate candidate;
    public Conflicts conflicts = new Conflicts();
    public long markedConflict;
    /** Visit id of the last accepted {@link #setDecision} call. */
    private long visitedState;

    public Activation(int id, Document doc, NodeActivation.Key key) {
        super(id, doc, key);
    }

    /**
     * Sets the training target value.
     *
     * <p>A non-null value registers this activation in
     * {@code doc.supervisedTraining.targetActivations}; {@code null} removes it.
     *
     * @param targetValue the target, or {@code null} to clear it
     */
    public void setTargetValue(Double targetValue) {
        this.targetValue = targetValue;
        if (targetValue != null) {
            this.doc.supervisedTraining.targetActivations.add(this);
        } else {
            this.doc.supervisedTraining.targetActivations.remove(this);
        }
    }

    /** @return the label of this activation's neuron */
    public String getLabel() {
        return this.getINeuron().label;
    }

    /** @return the document text covered by this activation's range */
    public String getText() {
        return this.doc.getText(this.key.range);
    }

    /** @return the neuron implementation, resolved through this activation's document */
    public INeuron getINeuron() {
        return (INeuron) this.getNeuron().get(this.doc);
    }

    /** @return the neuron referenced by this activation's {@link OrNode} */
    public Neuron getNeuron() {
        return ((OrNode) this.key.node).neuron;
    }

    /**
     * Registers a synapse link on this activation.
     *
     * @param dir {@code INPUT}: this activation is the link's input, so the link goes into
     *            {@link #neuronOutputs}. {@code OUTPUT}: this activation is the link's output, so
     *            the link goes into {@link #neuronInputs}, and also into
     *            {@link #selectedNeuronInputs} if the input activation is SELECTED.
     * @param sa  the link to register
     */
    public void addSynapseActivation(Linker.Direction dir, SynapseActivation sa) {
        switch (dir) {
            case INPUT: {
                this.neuronOutputs.add(sa);
                break;
            }
            case OUTPUT: {
                if (sa.input.decision == SearchNode.Decision.SELECTED) {
                    this.selectedNeuronInputs.add(sa);
                }
                this.neuronInputs.add(sa);
                break;
            }
        }
    }

    /**
     * Recomputes this activation's state for {@code round} during search visit {@code v}.
     *
     * <p>If the state changed, or {@code round == 0}:
     * <ul>
     *   <li>old and new rounds are recorded in {@code sn.modifiedActs};</li>
     *   <li>value changes are propagated through {@code doc.vQueue};</li>
     *   <li>in round 0, the activation is queued for round 1.</li>
     * </ul>
     *
     * @return the change in search weight contributed by this activation
     * @throws RuntimeException if a change must be propagated beyond {@link Document#MAX_ROUND}
     */
    public SearchNode.Weight process(SearchNode sn, int round, long v) {
        SearchNode.Weight delta = SearchNode.Weight.ZERO;
        State s = this.inputValue != null
                ? new State(this.inputValue, 0.0, 0, SearchNode.Weight.ZERO)
                : this.computeValueAndWeight(round);

        if (round == 0 || !this.rounds.get(round).equalsWithWeights(s)) {
            this.saveOldState(sn.modifiedActs, v);
            State oldState = this.rounds.get(round);
            boolean propagate = this.rounds.set(round, s) && (oldState == null || !oldState.equals(s));
            this.saveNewState();

            if (propagate) {
                if (round > Document.MAX_ROUND) {
                    log.error("Error: Maximum number of rounds reached. The network might be oscillating.");
                    log.info(this.doc.activationsToString(false, true, true));
                    this.doc.dumpOscillatingActivations();
                    throw new RuntimeException("Maximum number of rounds reached. The network might be oscillating.");
                }
                this.doc.vQueue.propagateActivationValue(round, this);
            }
            if (round == 0) {
                this.doc.vQueue.add(1, this);
            }

            Integer lastRound = this.rounds.getLastRound();
            if (lastRound != null && round >= lastRound) {
                // oldState is treated as possibly null above; avoid an NPE and count it as zero weight.
                SearchNode.Weight oldWeight = oldState != null ? oldState.weight : SearchNode.Weight.ZERO;
                delta = delta.add(s.weight.sub(oldWeight));
            }
        }
        return delta;
    }

    /**
     * Computes this activation's value and search weight for {@code round} from its input states.
     *
     * <p>The net input starts at the neuron's {@code biasSum} and adds {@code value * weight} of the
     * strongest input per synapse, skipping self-links. {@code netDir} counts only non-recurrent
     * synapses. The value is only non-zero if this activation is SELECTED or
     * {@link INeuron#ALLOW_WEAK_NEGATIVE_WEIGHTS} is set.
     */
    public State computeValueAndWeight(int round) {
        INeuron n = this.getINeuron();
        double net = n.biasSum;
        double netDir = n.biasSum;
        int fired = -1;

        for (InputState is : this.getInputStates(round)) {
            Synapse s = is.sa.synapse;
            Activation iAct = is.sa.input;
            if (iAct == this) {
                continue;
            }

            double x = is.s.value * s.weight;
            net += x;
            if (!s.key.isRecurrent) {
                netDir += x;
            }

            if (!s.key.isRecurrent && !s.isNegative() && net >= 0.0 && fired < 0) {
                fired = iAct.rounds.get(round).fired + 1;
            }
        }
        // NOTE(review): 'fired' is computed above but never used; both returns below pass -1.
        // Confirm whether the returned State should carry 'fired'.

        double currentActValue = n.activationFunction.f(net);
        double w = Math.min(-n.negRecSum, net);
        double norm = Math.min(-n.negRecSum, netDir + n.posRecSum);
        SearchNode.Weight newWeight = SearchNode.Weight.create(
                this.decision == SearchNode.Decision.SELECTED ? Math.max(0.0, w) : 0.0,
                Math.max(0.0, norm));

        if (this.decision == SearchNode.Decision.SELECTED || INeuron.ALLOW_WEAK_NEGATIVE_WEIGHTS) {
            return new State(currentActValue, net, -1, newWeight);
        }
        return new State(0.0, 0.0, -1, newWeight);
    }

    /**
     * Recomputes the bounds and reacts to changes.
     *
     * <p>If the upper bound moved by more than 0.01, every output activation is queued on
     * {@code doc.ubQueue}. If the upper bound became positive, the neuron propagates this activation.
     */
    public void processBounds() {
        double oldUpperBound = this.upperBound;
        this.computeBounds();

        if (Math.abs(this.upperBound - oldUpperBound) > 0.01) {
            for (SynapseActivation sa : this.neuronOutputs) {
                this.doc.ubQueue.add(sa.output);
            }
        }
        if (oldUpperBound <= 0.0 && this.upperBound > 0.0) {
            this.getINeuron().propagate(this);
        }
    }

    /**
     * Computes {@link #upperBound} and {@link #lowerBound} from the bounds of the input activations.
     *
     * <p>Both sums start at {@code biasSum + posRecSum}. Inactive synapses and self-links are skipped.
     * <ul>
     *   <li>Negative synapse, upper sum: adds {@code input.lowerBound * weight}, but only when the
     *       synapse is non-recurrent and not self-referencing.</li>
     *   <li>Negative synapse, lower sum: always adds {@code weight}.</li>
     *   <li>Other synapses: add {@code input.upperBound * weight} to the upper sum and
     *       {@code input.lowerBound * weight} to the lower sum.</li>
     * </ul>
     * Both sums are then passed through the neuron's activation function.
     */
    public void computeBounds() {
        INeuron n = this.getINeuron();
        double ub = n.biasSum + n.posRecSum;
        double lb = n.biasSum + n.posRecSum;

        for (SynapseActivation sa : this.neuronInputs) {
            Synapse s = sa.synapse;
            Activation iAct = sa.input;
            if (s.inactive || iAct == this) {
                continue;
            }

            if (s.isNegative()) {
                if (!s.key.isRecurrent && !Activation.checkSelfReferencing(this, iAct, false, 0)) {
                    ub += iAct.lowerBound * s.weight;
                }
                lb += s.weight;
            } else {
                ub += iAct.upperBound * s.weight;
                lb += iAct.lowerBound * s.weight;
            }
        }

        this.upperBound = n.activationFunction.f(ub);
        this.lowerBound = n.activationFunction.f(lb);
    }

    /** @return value 1.0 for a SELECTED decision, 0.0 otherwise (net 0, fired 0, zero weight) */
    private static State getInitialState(SearchNode.Decision c) {
        return new State(c == SearchNode.Decision.SELECTED ? 1.0 : 0.0, 0.0, 0, SearchNode.Weight.ZERO);
    }

    /**
     * Collects, for each active input synapse, the input link with the highest state value in
     * {@code round}.
     *
     * <p>This assumes that links of the same synapse are adjacent in {@link #neuronInputs}.
     * {@link SynapseActivation#INPUT_COMP} orders by synapse first, which provides that.
     */
    private List<InputState> getInputStates(int round) {
        ArrayList<InputState> tmp = new ArrayList<>();
        Synapse lastSynapse = null;
        InputState maxInputState = null;

        for (SynapseActivation sa : this.neuronInputs) {
            if (sa.synapse.inactive) {
                continue;
            }
            // A new synapse group starts: emit the best link of the previous group.
            if (lastSynapse != null && lastSynapse != sa.synapse) {
                tmp.add(maxInputState);
                maxInputState = null;
            }
            State s = sa.input.getInputState(round, this, sa.synapse);
            if (maxInputState == null || maxInputState.s.value < s.value) {
                maxInputState = new InputState(sa, s);
            }
            lastSynapse = sa.synapse;
        }
        if (maxInputState != null) {
            tmp.add(maxInputState);
        }
        return tmp;
    }

    /**
     * Returns this (input) activation's state as seen by {@code act} through synapse {@code s}.
     *
     * <ul>
     *   <li>Non-recurrent synapse: the state of {@code round}.</li>
     *   <li>Recurrent synapse: the state of {@code round - 1}, or the decision-based initial state in
     *       round 0.</li>
     *   <li>Negative recurrent synapse that closes a selected self-reference: {@link State#ZERO}.</li>
     * </ul>
     */
    private State getInputState(int round, Activation act, Synapse s) {
        State is = State.ZERO;
        if (s.key.isRecurrent) {
            if (!s.isNegative() || !Activation.checkSelfReferencing(act, this, true, 0)) {
                is = round == 0 ? Activation.getInitialState(this.decision) : this.rounds.get(round - 1);
            }
        } else {
            is = this.rounds.get(round);
        }
        return is;
    }

    /** @return input links whose input activation {@link #isFinalActivation() is final} */
    public List<SynapseActivation> getFinalInputActivations() {
        ArrayList<SynapseActivation> results = new ArrayList<>();
        for (SynapseActivation sa : this.neuronInputs) {
            if (sa.input.isFinalActivation()) {
                results.add(sa);
            }
        }
        return results;
    }

    /** @return output links whose output activation {@link #isFinalActivation() is final} */
    public List<SynapseActivation> getFinalOutputActivations() {
        ArrayList<SynapseActivation> results = new ArrayList<>();
        for (SynapseActivation sa : this.neuronOutputs) {
            if (sa.output.isFinalActivation()) {
                results.add(sa);
            }
        }
        return results;
    }

    /**
     * Adds each output link to (if {@code d} is SELECTED) or removes it from (otherwise) the
     * output activation's {@link #selectedNeuronInputs}.
     */
    public void adjustSelectedNeuronInputs(SearchNode.Decision d) {
        for (SynapseActivation sa : this.neuronOutputs) {
            if (d == SearchNode.Decision.SELECTED) {
                sa.output.selectedNeuronInputs.add(sa);
            } else {
                sa.output.selectedNeuronInputs.remove(sa);
            }
        }
    }

    /**
     * Checks whether {@code nx} is {@code ny} or reachable from {@code ny} by following
     * non-recurrent input links.
     *
     * @param onlySelected follow only {@link #selectedNeuronInputs} instead of all inputs
     * @param depth        current recursion depth; paths deeper than
     *                     {@link #MAX_SELF_REFERENCING_DEPTH} are treated as not referencing
     */
    public static boolean checkSelfReferencing(Activation nx, Activation ny, boolean onlySelected, int depth) {
        if (nx == ny) {
            return true;
        }
        if (depth > MAX_SELF_REFERENCING_DEPTH) {
            return false;
        }
        for (SynapseActivation sa : onlySelected ? ny.selectedNeuronInputs : ny.neuronInputs) {
            if (!sa.synapse.key.isRecurrent && Activation.checkSelfReferencing(nx, sa.input, onlySelected, depth + 1)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Sets the search decision for visit {@code v}.
     *
     * <p>The call is ignored when a fixed {@link #inputDecision} disagrees with the new decision.
     * It is also ignored when resetting to UNKNOWN from a visit other than the one that last set
     * the decision. When the SELECTED status flips, the output activations'
     * {@link #selectedNeuronInputs} are updated.
     */
    public void setDecision(SearchNode.Decision newDecision, long v) {
        if (this.inputDecision != SearchNode.Decision.UNKNOWN && newDecision != this.inputDecision) {
            return;
        }
        if (newDecision == SearchNode.Decision.UNKNOWN && v != this.visitedState) {
            return;
        }
        boolean wasSelected = this.decision == SearchNode.Decision.SELECTED;
        boolean isSelected = newDecision == SearchNode.Decision.SELECTED;
        if (wasSelected != isSelected) {
            this.adjustSelectedNeuronInputs(newDecision);
        }
        this.decision = newDecision;
        this.visitedState = v;
    }

    /** @return {@code true} if the last final state has a positive value */
    public boolean isFinalActivation() {
        return this.getFinalState().value > 0.0;
    }

    /** @return the last state stored in {@link #finalRounds} */
    public State getFinalState() {
        return this.finalRounds.getLast();
    }

    /**
     * Checks whether this activation matches the given criteria. A {@code null} criterion matches
     * anything.
     *
     * @param n   required node (compared by identity)
     * @param rid required relational id
     * @param r   range to compare against; only used if {@code rr} is also non-null
     * @param rr  relation that must hold between this activation's range and {@code r}
     */
    public <T extends Node> boolean filter(T n, Integer rid, Range r, Range.Relation rr) {
        if (n != null && this.key.node != n) {
            return false;
        }
        if (rid != null && (this.key.rid == null || this.key.rid.intValue() != rid.intValue())) {
            return false;
        }
        if (r != null && rr != null && !rr.compare(this.key.range, r)) {
            return false;
        }
        return true;
    }

    /**
     * Returns the sequence number of this activation and caches it.
     *
     * <p>The number is 0 without non-recurrent inputs. Otherwise it is one more than the largest
     * sequence number among the non-recurrent inputs. The field is set to 0 before recursing, so a
     * cycle back to this activation reads 0 instead of recursing forever.
     */
    public Integer getSequence() {
        if (this.sequence != null) {
            return this.sequence;
        }
        this.sequence = 0;
        for (SynapseActivation sa : this.neuronInputs) {
            if (sa.synapse.key.isRecurrent) {
                continue;
            }
            this.sequence = Math.max(this.sequence, sa.input.getSequence() + 1);
        }
        return this.sequence;
    }

    /** Raises {@link #markedDirty} to {@code v} if it is larger. */
    public void markDirty(long v) {
        this.markedDirty = Math.max(this.markedDirty, v);
    }

    @Override
    public String toString() {
        return this.key + " -" + " UB:" + Utils.round(this.upperBound)
                + (this.inputValue != null ? " IV:" + Utils.round(this.inputValue) : "")
                + (this.targetValue != null ? " TV:" + Utils.round(this.targetValue) : "")
                + " V:" + Utils.round(this.rounds.getLast().value)
                + " FV:" + Utils.round(this.finalRounds.getLast().value);
    }

    /**
     * Builds a detailed, single-line description of this activation.
     *
     * @param finalOnly       show the final decision and final state instead of the current
     *                        decision and all rounds
     * @param withTextSnippet include the neuron's output text, or else the covered document text
     * @param withLogic       include the full node description instead of only the neuron label
     */
    public String toString(boolean finalOnly, boolean withTextSnippet, boolean withLogic) {
        StringBuilder sb = new StringBuilder();
        sb.append(this.id).append(" - ");
        sb.append(finalOnly ? this.finalDecision : this.decision).append(" - ");
        sb.append(this.key.range);

        if (withTextSnippet) {
            INeuron neuron = (INeuron) this.getNeuron().get();
            sb.append(" \"");
            if (neuron.outputText != null) {
                sb.append(Utils.collapseText(neuron.outputText, 7));
            } else {
                sb.append(Utils.collapseText(this.doc.getText(this.key.range), 7));
            }
            sb.append("\"");
        }

        OrNode node = (OrNode) this.key.node;
        sb.append(" - ");
        sb.append(withLogic ? node.toString() : node.getNeuronLabel());
        sb.append(" - RID:");
        sb.append(this.key.rid);
        sb.append(" - UB:");
        sb.append(Utils.round(this.upperBound));
        sb.append(" - ");

        if (finalOnly) {
            if (this.isFinalActivation()) {
                sb.append(this.getFinalState());
            }
        } else {
            for (Map.Entry<Integer, State> me : this.rounds.rounds.entrySet()) {
                sb.append("[R: " + me.getKey() + " " + me.getValue() + "]");
            }
        }

        if (this.inputValue != null) {
            sb.append(" - IV:" + Utils.round(this.inputValue));
        }
        if (this.targetValue != null) {
            sb.append(" - TV:" + Utils.round(this.targetValue));
        }
        return sb.toString();
    }

    /** @return one line per input link with the input's label and the synapse weight */
    public String linksToString() {
        StringBuilder sb = new StringBuilder();
        for (SynapseActivation sa : this.neuronInputs) {
            sb.append("  " + sa.input.getLabel() + "  W:" + sa.synapse.weight + "\n");
        }
        return sb.toString();
    }

    /**
     * Records a copy of the current rounds as the "old" state for visit {@code v}.
     *
     * <p>The copy is only taken once per visit. When a new {@link StateChange} is created and
     * {@code changes} is non-null, it is registered there under this activation.
     */
    public void saveOldState(Map<Activation, StateChange> changes, long v) {
        StateChange sc = this.currentStateChange;
        if (sc == null || this.currentStateV != v) {
            sc = new StateChange();
            sc.oldRounds = this.rounds.copy();
            this.currentStateChange = sc;
            this.currentStateV = v;
            if (changes != null) {
                changes.put(sc.getActivation(), sc);
            }
        }
    }

    /**
     * Records a copy of the current rounds and decision as the "new" state.
     * Must follow {@link #saveOldState}, which creates the {@link #currentStateChange}.
     */
    public void saveNewState() {
        StateChange sc = this.currentStateChange;
        sc.newRounds = this.rounds.copy();
        sc.newState = this.decision;
    }

    /** Fluent holder for the parameters of a new activation (range, relational id, value, target, fired). */
    public static class Builder {
        public Range range;
        public Integer rid;
        public double value = 1.0;
        public Double targetValue;
        public int fired;

        public Builder setRange(int begin, int end) {
            this.range = new Range(begin, end);
            return this;
        }

        public Builder setRange(Range range) {
            this.range = range;
            return this;
        }

        public Builder setRelationalId(Integer rid) {
            this.rid = rid;
            return this;
        }

        public Builder setValue(double value) {
            this.value = value;
            return this;
        }

        public Builder setTargetValue(Double targetValue) {
            this.targetValue = targetValue;
            return this;
        }

        public Builder setFired(int fired) {
            this.fired = fired;
            return this;
        }
    }

    /** Rounds of the enclosing activation before and after a change, so either can be restored. */
    public class StateChange {
        public Rounds oldRounds;
        public Rounds newRounds;
        public SearchNode.Decision newState;

        /** Replaces the enclosing activation's rounds with a copy of the old or new snapshot. */
        public void restoreState(Mode m) {
            Activation.this.rounds = (m == Mode.OLD ? this.oldRounds : this.newRounds).copy();
        }

        public Activation getActivation() {
            return Activation.this;
        }
    }

    /** Selects which snapshot {@link StateChange#restoreState} restores. */
    public static enum Mode {
        OLD,
        NEW
    }

    /** Immutable activation state for one round: value, net input, fired depth and search weight. */
    public static class State {
        public static final int DIR = 0;
        public static final int REC = 1;
        public final double value;
        public final double net;
        public final int fired;
        public final SearchNode.Weight weight;
        public static final State ZERO = new State(0.0, 0.0, -1, SearchNode.Weight.ZERO);

        public State(double value, double net, int fired, SearchNode.Weight weight) {
            assert (!Double.isNaN(value));
            this.value = value;
            this.net = net;
            this.fired = fired;
            this.weight = weight;
        }

        /**
         * Compares only {@code value}, within {@link INeuron#WEIGHT_TOLERANCE}.
         * Note: this OVERLOADS (does not override) {@link Object#equals(Object)}.
         */
        public boolean equals(State s) {
            return Math.abs(this.value - s.value) <= INeuron.WEIGHT_TOLERANCE;
        }

        /** Like {@link #equals(State)}, but the weights must also be equal. */
        public boolean equalsWithWeights(State s) {
            return this.equals(s) && this.weight.equals(s.weight);
        }

        @Override
        public String toString() {
            return "V:" + Utils.round(this.value) + " " + this.weight;
        }
    }

    /**
     * Sparse per-round state history.
     *
     * <p>A state is only stored for rounds where it changes. {@link #get(int)} returns the state of
     * the latest stored round at or below the requested round. Round 0 is seeded with
     * {@link State#ZERO}.
     */
    public static class Rounds {
        private boolean[] isQueued = new boolean[3];
        public TreeMap<Integer, State> rounds = new TreeMap<>();

        public Rounds() {
            this.rounds.put(0, State.ZERO);
        }

        /**
         * Stores {@code s} for round {@code r}.
         *
         * <p>If round {@code r - 1} already holds an equal state (including weights), any entry for
         * {@code r} is removed instead. Otherwise {@code s} is stored and every later entry equal to
         * {@code s} is removed.
         *
         * @return {@code true} if the stored entry for round {@code r} changed
         */
        public boolean set(int r, State s) {
            State lr = this.get(r - 1);
            if (lr != null && lr.equalsWithWeights(s)) {
                State or = this.rounds.get(r);
                if (or != null) {
                    this.rounds.remove(r);
                    return !or.equalsWithWeights(s);
                }
                return false;
            }

            State or = this.rounds.put(r, s);
            Iterator<Map.Entry<Integer, State>> it = this.rounds.tailMap(r + 1).entrySet().iterator();
            while (it.hasNext()) {
                if (it.next().getValue().equalsWithWeights(s)) {
                    it.remove();
                }
            }
            return or == null || !or.equalsWithWeights(s);
        }

        /** @return the state of the latest stored round {@code <= r}, or {@code null} if there is none */
        public State get(int r) {
            Map.Entry<Integer, State> me = this.rounds.floorEntry(r);
            return me != null ? me.getValue() : null;
        }

        /** @return a copy of the stored states; the queued flags are not copied */
        public Rounds copy() {
            Rounds nr = new Rounds();
            nr.rounds.putAll(this.rounds);
            return nr;
        }

        public Integer getLastRound() {
            return !this.rounds.isEmpty() ? this.rounds.lastKey() : null;
        }

        /** @return the state of the highest stored round, or {@link State#ZERO} if empty */
        public State getLast() {
            return !this.rounds.isEmpty() ? this.rounds.lastEntry().getValue() : State.ZERO;
        }

        /**
         * Sets the queued flag for round {@code r}, growing the flag array as needed.
         * The array grows to at least {@code r + 1}; doubling alone would overflow for
         * {@code r >= 2 * length}.
         */
        public void setQueued(int r, boolean v) {
            if (r >= this.isQueued.length) {
                this.isQueued = Arrays.copyOf(this.isQueued, Math.max(this.isQueued.length * 2, r + 1));
            }
            this.isQueued[r] = v;
        }

        public boolean isQueued(int r) {
            return r < this.isQueued.length && this.isQueued[r];
        }

        /** Clears all rounds and re-seeds round 0 with {@link State#ZERO}. */
        public void reset() {
            this.rounds.clear();
            this.rounds.put(0, State.ZERO);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            this.rounds.forEach((r, s) -> sb.append(r + ":" + s.value + " "));
            return sb.toString();
        }

        /**
         * @return {@code true} if {@code r} stores the same rounds with values equal within 1e-7
         */
        public boolean compare(Rounds r) {
            if (this.rounds.size() != r.rounds.size()) {
                return false;
            }
            for (Map.Entry<Integer, State> me : this.rounds.entrySet()) {
                State mine = me.getValue();
                State other = r.rounds.get(me.getKey());
                if (other == null || Math.abs(mine.value - other.value) > 1.0E-7) {
                    return false;
                }
            }
            return true;
        }

        /** @return {@code true} if at most one round is stored and the last state's value is positive */
        public boolean isActive() {
            return this.rounds.size() <= 1 && this.getLast().value > 0.0;
        }
    }

    /** A link between an input and an output activation through a synapse. */
    public static class SynapseActivation {
        public final Synapse synapse;
        public final Activation input;
        public final Activation output;
        /** Orders by {@link Synapse#INPUT_SYNAPSE_COMP}, then by input activation. */
        public static Comparator<SynapseActivation> INPUT_COMP = (sa1, sa2) -> {
            int r = Synapse.INPUT_SYNAPSE_COMP.compare(sa1.synapse, sa2.synapse);
            if (r != 0) {
                return r;
            }
            return sa1.input.compareTo(sa2.input);
        };
        /** Orders by {@link Synapse#OUTPUT_SYNAPSE_COMP}, then by output activation. */
        public static Comparator<SynapseActivation> OUTPUT_COMP = (sa1, sa2) -> {
            int r = Synapse.OUTPUT_SYNAPSE_COMP.compare(sa1.synapse, sa2.synapse);
            if (r != 0) {
                return r;
            }
            return sa1.output.compareTo(sa2.output);
        };

        public SynapseActivation(Synapse s, Activation input, Activation output) {
            this.synapse = s;
            this.input = input;
            this.output = output;
        }
    }

    /** An input link paired with the input state selected for the current round. */
    private static class InputState {
        final SynapseActivation sa;
        final State s;

        public InputState(SynapseActivation sa, State s) {
            this.sa = sa;
            this.s = s;
        }
    }
}

