/*
 * Decompiled with CFR 0.152.
 */
package network.aika;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import network.aika.AbstractNode;
import network.aika.Converter;
import network.aika.Model;
import network.aika.Provider;
import network.aika.lattice.Node;
import network.aika.lattice.NodeActivation;
import network.aika.lattice.OrNode;
import network.aika.neuron.INeuron;
import network.aika.neuron.Synapse;
import network.aika.neuron.activation.Activation;
import network.aika.neuron.activation.Candidate;
import network.aika.neuron.activation.Linker;
import network.aika.neuron.activation.Position;
import network.aika.neuron.activation.SearchNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A document processed by the network. It holds the text content together with all
 * per-document processing state: activations and their position indexes, the change/value/
 * upper-bound queues, search state and pending weight modifications.
 *
 * <p>Documents are ordered by {@link #id}.
 */
public class Document implements Comparable<Document> {
    private static final Logger log = LoggerFactory.getLogger(Document.class);

    /**
     * Interval (in document ids) used by {@link #clearActivations()} to decide when to drop
     * stale per-thread node state.
     */
    public static int CLEANUP_INTERVAL = 500;
    /**
     * Round limit. {@link #dumpOscillatingActivations()} reports activations whose last round
     * is above {@code MAX_ROUND - 5}.
     */
    public static int MAX_ROUND = 20;
    /** Not referenced in this class. NOTE(review): presumably read by the search code. */
    public static int ROUND_LIMIT = -1;
    /**
     * When {@code true}, {@link #generateCandidates()} only considers activations added since
     * the previous call, and {@link #process(Long)} continues from the previously selected
     * search node instead of creating a new root.
     */
    public static boolean INCREMENTAL_MODE = false;

    /** Document id; also defines the natural ordering of documents. */
    public final int id;
    private final StringBuilder content;

    /** Source of fresh visit markers; each consumer post-increments it. */
    public long visitedCounter = 1L;
    // Id counters for objects belonging to this document. They are not advanced in this class.
    public int activationIdCounter = 0;
    public int logicNodeActivationIdCounter = 0;
    public int searchNodeIdCounter = 0;
    public int searchStepCounter = 0;
    public int positionIdCounter = 0;

    public Model model;
    /** Index of the processing thread; selects per-thread slots in the model and in nodes. */
    public int threadId;

    public Queue queue = new Queue();
    public ValueQueue vQueue = new ValueQueue();
    public UpperBoundQueue ubQueue = new UpperBoundQueue();
    public Linker linker;

    /** Positions created through {@link #lookupFinalPosition(int)}, keyed by character offset. */
    public TreeMap<Integer, Position> positions = new TreeMap<>();
    public TreeSet<Node> activatedNodes = new TreeSet<>();
    public TreeSet<INeuron> activatedNeurons = new TreeSet<>();
    /** Neurons that have at least one final activation after {@link #process(Long)}. */
    public TreeSet<INeuron> finallyActivatedNeurons = new TreeSet<>();
    public TreeSet<Activation> inputNeuronActivations = new TreeSet<>();
    /** Synapses with modified weights grouped by output neuron; flushed by {@link #commit()}. */
    public TreeMap<INeuron, Set<Synapse>> modifiedWeights = new TreeMap<>();

    /**
     * Activation index ordered by (slot, position, node, activation id).
     * The explicit type arguments are required: with a raw {@code TreeMap} the lambda
     * parameters would be typed {@code Object} and the {@code ActKey} field accesses would
     * not compile.
     */
    private final TreeMap<ActKey, Activation> activationsBySlotAndPosition =
            new TreeMap<ActKey, Activation>((ak1, ak2) -> {
                int r = Integer.compare(ak1.slot, ak2.slot);
                if (r != 0) {
                    return r;
                }
                r = Position.compare(ak1.pos, ak2.pos);
                if (r != 0) {
                    return r;
                }
                r = ak1.node.compareTo(ak2.node);
                if (r != 0) {
                    return r;
                }
                return Integer.compare(ak1.actId, ak2.actId);
            });

    /** Activation index ordered by (position, node, activation id); the slot is ignored. */
    private final TreeMap<ActKey, Activation> activationsByPosition =
            new TreeMap<ActKey, Activation>((ak1, ak2) -> {
                int r = Position.compare(ak1.pos, ak2.pos);
                if (r != 0) {
                    return r;
                }
                r = ak1.node.compareTo(ak2.node);
                if (r != 0) {
                    return r;
                }
                return Integer.compare(ak1.actId, ak2.actId);
            });

    /** All activations registered via {@link #addActivation(Activation)}, keyed by id. */
    public TreeMap<Integer, Activation> activationsById = new TreeMap<>();
    /** Highest activation id turned into a candidate so far; used in incremental mode. */
    private int lastProcessedActivationId = -1;
    public TreeSet<Node> addedNodes = new TreeSet<>();
    public ArrayList<NodeActivation> addedNodeActivations = new ArrayList<>();
    /**
     * Search node the search starts from. {@link #process(Long)} replaces it with a fresh root
     * unless incremental mode reuses an existing one.
     * NOTE(review): presumably updated by SearchNode.search to the finally selected node.
     */
    public SearchNode selectedSearchNode;
    /** Candidates in dependency order, produced by {@link #generateCandidates()}. */
    public ArrayList<Candidate> candidates = new ArrayList<>();
    /** Not referenced in this class. */
    public long createV;

    /** Orders activations for output: by begin position, then by node, then by activation id. */
    public static Comparator<Activation> ACTIVATIONS_OUTPUT_COMPARATOR = (act1, act2) -> {
        int r = Position.compare(act1.getSlot(Activation.BEGIN), act2.getSlot(Activation.BEGIN));
        if (r != 0) {
            return r;
        }
        r = ((OrNode) act1.node).compareTo(act2.node);
        if (r != 0) {
            return r;
        }
        return Integer.compare(act1.id, act2.id);
    };

    /** Orders activations inside a value-queue round: by sequence, then by id. */
    private static final Comparator<Activation> VALUE_QUEUE_COMP = (a, b) -> {
        int r = Integer.compare(a.getSequence(), b.getSequence());
        if (r != 0) {
            return r;
        }
        return Integer.compare(a.id, b.id);
    };

    /**
     * Creates a document and its linker.
     *
     * @param id       document id
     * @param content  initial text content
     * @param model    owning model; supplies the linker factory
     * @param threadId index of the processing thread
     */
    public Document(int id, String content, Model model, int threadId) {
        this.id = id;
        this.content = new StringBuilder(content);
        this.model = model;
        this.threadId = threadId;
        this.linker = model.getLinkerFactory().createLinker(this);
    }

    /** Appends {@code txt} to the content. */
    public void append(String txt) {
        this.content.append(txt);
    }

    /** Returns the character at offset {@code i} of the content. */
    public char charAt(int i) {
        return this.content.charAt(i);
    }

    /** Returns the current content as a string. */
    public String getContent() {
        return this.content.toString();
    }

    /** Returns the current content length in characters. */
    public int length() {
        return this.content.length();
    }

    @Override
    public String toString() {
        return this.content.toString();
    }

    /**
     * Returns the position object for character offset {@code pos}, creating and caching it
     * on first use.
     */
    public Position lookupFinalPosition(int pos) {
        Position p = this.positions.get(pos);
        if (p == null) {
            p = new Position(this, pos);
            this.positions.put(pos, p);
        }
        return p;
    }

    /** Returns the text between the final offsets of {@code begin} and {@code end}. */
    public String getText(Position begin, Position end) {
        return this.getText(begin.getFinalPosition(), end.getFinalPosition());
    }

    /**
     * Returns the content between {@code begin} (inclusive) and {@code end} (exclusive),
     * with both offsets clamped to {@code [0, length()]}.
     *
     * @return the substring, or {@code ""} if either offset is {@code null}
     */
    public String getText(Integer begin, Integer end) {
        if (begin == null || end == null) {
            return "";
        }
        int len = this.length();
        return this.content.substring(Math.max(0, Math.min(begin, len)), Math.max(0, Math.min(end, len)));
    }

    /**
     * Registers an activation by id. Each slot whose position has a final offset is also
     * added to both position indexes.
     */
    public void addActivation(Activation act) {
        for (Map.Entry<Integer, Position> me : act.slots.entrySet()) {
            Position pos = me.getValue();
            // Slots without a resolved offset cannot be ordered by position yet.
            if (pos == null || pos.getFinalPosition() == null) {
                continue;
            }
            ActKey dak = new ActKey(me.getKey(), pos, act.node, act.id);
            this.activationsBySlotAndPosition.put(dak, act);
            this.activationsByPosition.put(dak, act);
        }
        this.activationsById.put(act.id, act);
    }

    /**
     * Returns the activations of this document.
     *
     * @param onlyFinal if {@code true}, only activations whose {@code isFinalActivation()}
     *                  holds are returned (as a new list); otherwise a live view of all
     *                  activations is returned
     */
    public Collection<Activation> getActivations(boolean onlyFinal) {
        if (!onlyFinal) {
            return this.activationsById.values();
        }
        return this.activationsById.values().stream()
                .filter(Activation::isFinalActivation)
                .collect(Collectors.toList());
    }

    /**
     * Returns a live view of the activations whose (slot, position) key lies between
     * ({@code fromSlot}, {@code fromPos}) and ({@code toSlot}, {@code toPos}).
     */
    public Collection<Activation> getActivationsByPosition(int fromSlot, Position fromPos, boolean fromInclusive, int toSlot, Position toPos, boolean toInclusive) {
        // MIN/MAX node and id values make the bounds cover every activation at the boundary keys.
        return this.activationsBySlotAndPosition.subMap(
                new ActKey(fromSlot, fromPos, Node.MIN_NODE, Integer.MIN_VALUE), fromInclusive,
                new ActKey(toSlot, toPos, Node.MAX_NODE, Integer.MAX_VALUE), toInclusive).values();
    }

    /**
     * Returns a live view of the activations with a slot position between {@code fromPos} and
     * {@code toPos}, regardless of slot.
     */
    public Collection<Activation> getActivationsByPosition(Position fromPos, boolean fromInclusive, Position toPos, boolean toInclusive) {
        // The slot value -1 is irrelevant here: this index's comparator ignores slots.
        return this.activationsByPosition.subMap(
                new ActKey(-1, fromPos, Node.MIN_NODE, Integer.MIN_VALUE), fromInclusive,
                new ActKey(-1, toPos, Node.MAX_NODE, Integer.MAX_VALUE), toInclusive).values();
    }

    /**
     * Returns the activation with the next higher id, the first activation if
     * {@code currentAct} is {@code null}, or {@code null} if there is none.
     */
    public Activation getNextActivation(Activation currentAct) {
        Map.Entry<Integer, Activation> me = currentAct == null
                ? this.activationsById.firstEntry()
                : this.activationsById.higherEntry(currentAct.id);
        return me != null ? me.getValue() : null;
    }

    /** Returns the number of registered activations. */
    public int getNumberOfActivations() {
        return this.activationsById.size();
    }

    @Override
    public int compareTo(Document doc) {
        return Integer.compare(this.id, doc.id);
    }

    /**
     * Alternates node-change processing and upper-bound processing until the upper-bound
     * queue reports that it had nothing to do.
     */
    public void propagate() {
        do {
            this.queue.processChanges();
        } while (this.ubQueue.process());
    }

    /**
     * Builds {@link #candidates} in dependency order. Candidates are the activations whose
     * decision is still UNKNOWN and whose upper bound is positive; in incremental mode only
     * activations with an id above the last processed one are considered.
     *
     * @throws RuntimeException if the remaining candidates form a dependency cycle
     */
    public void generateCandidates() {
        TreeSet<Candidate> tmp = new TreeSet<>();
        int i = 0;
        if (!INCREMENTAL_MODE) {
            this.candidates.clear();
        }
        int fromId = INCREMENTAL_MODE ? this.lastProcessedActivationId : -1;
        for (Activation act : this.activationsById.subMap(fromId, false, Integer.MAX_VALUE, true).values()) {
            // Written as !(x > 0.0) so that a NaN upper bound is also skipped.
            if (act.decision != SearchNode.Decision.UNKNOWN || !(act.upperBound > 0.0)) {
                continue;
            }
            SearchNode.invalidateCachedDecision(act);
            tmp.add(new Candidate(act, i++));
            this.lastProcessedActivationId = Math.max(this.lastProcessedActivationId, act.id);
        }

        // Mark the input activations with a fresh visit marker.
        // NOTE(review): presumably so that checkDependenciesSatisfied treats them as satisfied.
        long v = this.visitedCounter++;
        for (Activation act : this.inputNeuronActivations) {
            act.markedHasCandidate = v;
        }

        // Topological ordering: repeatedly move the first candidate whose dependencies are
        // satisfied. A pass that moves nothing means the remaining candidates form a cycle.
        while (!tmp.isEmpty()) {
            int oldSize = tmp.size();
            for (Candidate c : tmp) {
                if (!c.checkDependenciesSatisfied(v)) {
                    continue;
                }
                // Safe despite the enhanced for: the iterator is abandoned right after removal.
                tmp.remove(c);
                c.id = this.candidates.size();
                this.candidates.add(c);
                c.activation.markedHasCandidate = v;
                break;
            }
            if (tmp.size() == oldSize) {
                log.error("Cycle detected in the activations that is not marked recurrent.");
                throw new RuntimeException("Cycle detected in the activations that is not marked recurrent.");
            }
        }
    }

    /** Processes the document without a timeout; equivalent to {@code process(null)}. */
    public void process() {
        this.process(null);
    }

    /**
     * Runs the processing pipeline:
     * <ol>
     *   <li>late linking;</li>
     *   <li>value propagation from the input activations;</li>
     *   <li>candidate generation;</li>
     *   <li>the search;</li>
     *   <li>collection of finally activated neurons;</li>
     *   <li>optional soft-max averaging.</li>
     * </ol>
     *
     * @param timeoutInMilliSeconds passed unchanged to SearchNode.search
     *        (NOTE(review): {@code null} presumably means no timeout)
     * @throws SearchNode.TimeoutException as declared; presumably raised by the search when the
     *         timeout expires
     */
    public void process(Long timeoutInMilliSeconds) throws SearchNode.TimeoutException {
        this.linker.lateLinking();
        this.inputNeuronActivations.forEach(act -> this.vQueue.propagateActivationValue(0, act));
        this.generateCandidates();

        SearchNode rootNode = null;
        if (this.selectedSearchNode == null || !INCREMENTAL_MODE) {
            rootNode = this.selectedSearchNode = new SearchNode(this, null, null, 0);
        }
        SearchNode.search(this, this.selectedSearchNode, this.visitedCounter++, timeoutInMilliSeconds);

        for (Activation act : this.activationsById.values()) {
            if (act.isFinalActivation()) {
                this.finallyActivatedNeurons.add(act.getINeuron());
            }
        }

        // In incremental mode with an existing search node no new root is created, so there is
        // no normalisation base available here. Computing the soft max previously failed with a
        // NullPointerException in that case; it is now skipped.
        if (SearchNode.COMPUTE_SOFT_MAX && rootNode != null) {
            this.computeSoftMax(rootNode);
        }
    }

    /**
     * Sets {@code avgState} of every activation that has search states to the weighted average
     * of those states. Each state's weight is divided by the root's weight exp-sum.
     */
    private void computeSoftMax(SearchNode rootNode) {
        double norm = rootNode.getWeightExpSum();
        for (Activation act : this.activationsById.values()) {
            if (act.searchStates == null) {
                continue;
            }
            double avgValue = 0.0;
            double avgPosValue = 0.0;
            double avgP = 0.0;
            double avgNet = 0.0;
            double avgPosNet = 0.0;
            for (Activation.AvgState avgState : act.searchStates) {
                double p = avgState.weight / norm;
                Activation.State s = avgState.state;
                avgValue += p * s.value;
                avgPosValue += p * s.posValue;
                avgP += p * s.p;
                avgNet += p * s.net;
                avgPosNet += p * s.posNet;
            }
            act.avgState = new Activation.State(avgValue, avgPosValue, avgP, avgNet, avgPosNet, 0, 0.0);
        }
    }

    /** Logs every candidate at INFO level. */
    public void dumpDebugCandidateStatistics() {
        for (Candidate c : this.candidates) {
            log.info(c.toString());
        }
    }

    /** Records {@code synapse} as modified, grouped under its output neuron, until {@link #commit()}. */
    public void notifyWeightModified(Synapse synapse) {
        // Resolve the output neuron once; the original looked it up twice on first insertion.
        INeuron outputNeuron = (INeuron) synapse.output.get();
        this.modifiedWeights
                .computeIfAbsent(outputNeuron, n -> new TreeSet<Synapse>(Synapse.INPUT_SYNAPSE_COMP))
                .add(synapse);
    }

    /** Runs the Converter for every neuron with modified weights, then clears the pending set. */
    public void commit() {
        this.modifiedWeights.forEach((n, inputSyns) -> Converter.convert(this.threadId, this, n, inputSyns));
        this.modifiedWeights.clear();
    }

    /**
     * Clears this document's activation state and performs periodic cleanup:
     * <ul>
     *   <li>asks every activated neuron and node to clear its activations for this document;</li>
     *   <li>empties the local activation bookkeeping;</li>
     *   <li>once more than {@link #CLEANUP_INTERVAL} ids have passed since the last cleanup on
     *       this thread, drops per-thread node state not used within that interval;</li>
     *   <li>removes the document from the model's slot for this thread.</li>
     * </ul>
     * NOTE(review): the two position indexes are not cleared here — confirm the document is
     * never reused after this call.
     * NOTE(review): CFR reported "Removed try catching itself" when decompiling this method, so
     * the synchronized block may not match the original bytecode exactly.
     */
    public void clearActivations() {
        this.activatedNeurons.forEach(n -> n.clearActivations(this));
        this.activatedNodes.forEach(n -> n.clearActivations(this));
        this.activationsById.clear();
        this.addedNodeActivations.clear();
        this.activatedNeurons.clear();
        this.activatedNodes.clear();
        this.addedNodes.clear();

        if (this.model.lastCleanup[this.threadId] + CLEANUP_INTERVAL < this.id) {
            this.model.lastCleanup[this.threadId] = this.id;
            Map<Integer, Provider<? extends AbstractNode>> providers = this.model.activeProviders;
            ArrayList<Provider<? extends AbstractNode>> snapshot;
            // Copy under the map's monitor so the iteration below needs no lock.
            synchronized (providers) {
                snapshot = new ArrayList<>(providers.values());
            }
            snapshot.forEach(np -> {
                Object an = np.getIfNotSuspended();
                if (an instanceof Node) {
                    Node n = (Node) an;
                    Node.ThreadState th = n.threads[this.threadId];
                    if (th != null && th.lastUsed + (long) CLEANUP_INTERVAL < (long) this.id) {
                        n.threads[this.threadId] = null;
                    }
                }
            });
        }
        this.model.docs[this.threadId] = null;
    }

    /**
     * Writes the output text of final activations into the content. Starting from begin
     * positions with a known offset whose end offset is still unknown, each writing activation
     * fixes its end offset, and that end position is processed next.
     *
     * @return the part of the content beyond its length at the start of this call
     */
    public String generateOutputText() {
        int oldLength = this.length();
        TreeSet<Position> pending = new TreeSet<Position>(Comparator.comparingInt(p -> p.id));
        for (Activation act : this.activationsById.values()) {
            if (act.getINeuron().getOutputText() == null) {
                continue;
            }
            Position begin = act.getSlot(Activation.BEGIN);
            if (begin.getFinalPosition() == null || act.getSlot(Activation.END).getFinalPosition() != null) {
                continue;
            }
            pending.add(begin);
        }
        while (!pending.isEmpty()) {
            Position pos = pending.pollFirst();
            pos.getActivations(Activation.BEGIN)
                    .filter(act -> act.getINeuron().getOutputText() != null && act.isFinalActivation())
                    .forEach(act -> {
                        String outText = act.getINeuron().getOutputText();
                        Position nextPos = act.getSlot(Activation.END);
                        nextPos.setFinalPosition(pos.getFinalPosition() + outText.length());
                        this.content.replace(act.getSlot(Activation.BEGIN).getFinalPosition(), act.getSlot(Activation.END).getFinalPosition(), outText);
                        pending.add(nextPos);
                    });
        }
        return this.content.substring(oldLength, this.length());
    }

    /** Equivalent to {@code activationsToString(false)}. */
    public String activationsToString() {
        return this.activationsToString(false);
    }

    /**
     * Renders a table of all activations that have a positive upper bound or a positive
     * target value, ordered by {@link #ACTIVATIONS_OUTPUT_COMPARATOR}.
     *
     * @param withLogic whether to include the logic-layer column
     */
    public String activationsToString(boolean withLogic) {
        TreeSet<Activation> acts = new TreeSet<>(ACTIVATIONS_OUTPUT_COMPARATOR);
        acts.addAll(this.activationsById.values());
        StringBuilder sb = new StringBuilder();
        sb.append("Id -");
        sb.append(" Decision -");
        sb.append(" Range | Text Snippet");
        sb.append(" | Identity -");
        sb.append(" Neuron Label -");
        sb.append(withLogic ? " Logic Layer -" : "");
        sb.append(" Upper Bound -");
        sb.append(" Value | Net | Weight -");
        sb.append(" Input Value |");
        sb.append(" Target Value");
        sb.append("\n");
        sb.append("\n");
        for (Activation act : acts) {
            if (act.upperBound <= 0.0 && (act.targetValue == null || act.targetValue <= 0.0)) {
                continue;
            }
            sb.append(act.toString(withLogic));
            sb.append("\n");
        }
        if (this.selectedSearchNode != null) {
            sb.append("\n Final SearchNode:" + this.selectedSearchNode.id + "  WeightSum:" + this.selectedSearchNode.accumulatedWeight + "\n");
        }
        return sb.toString();
    }

    /** Returns a stream over {@link #addedNodeActivations}. */
    public Stream<NodeActivation> getAllActivationsStream() {
        return this.addedNodeActivations.stream();
    }

    /** Logs, at ERROR level, activations whose last round is above {@code MAX_ROUND - 5}. */
    public void dumpOscillatingActivations() {
        this.activatedNeurons.stream()
                .flatMap(n -> n.getActivations(this, false))
                .filter(act -> act.rounds.getLastRound() != null && act.rounds.getLastRound() > MAX_ROUND - 5)
                .forEach(act -> {
                    log.error("{} {} {} {}", act.id, act.slotsToString(), act.decision, act.rounds);
                    log.error(act.linksToString());
                    log.error("");
                });
    }

    /** Per-round queues of activations whose values need to be (re)computed. */
    public class ValueQueue {
        /** One ordered queue per round, indexed by round number. */
        public final ArrayList<TreeSet<Activation>> queue = new ArrayList<>();

        /**
         * Queues the outputs of {@code act}'s output links: recurrent links in the next round,
         * the others in {@code round}.
         */
        public void propagateActivationValue(int round, Activation act) {
            act.getOutputLinks(false).forEach(l -> this.add(l.synapse.isRecurrent ? round + 1 : round, l.output));
        }

        /** Queues {@code act} and the outputs of its recurrent output links in round 0. */
        private void add(Activation act) {
            this.add(0, act);
            act.getOutputLinks(false).filter(l -> l.synapse.isRecurrent).forEach(l -> this.add(0, l.output));
        }

        /**
         * Queues {@code act} in {@code round}. It is ignored if it is already queued for that
         * round or its decision is still UNKNOWN.
         * Rounds must be created in order: {@code round} may be at most the current round count.
         */
        public void add(int round, Activation act) {
            if (act.rounds.isQueued(round) || act.decision == SearchNode.Decision.UNKNOWN) {
                return;
            }
            TreeSet<Activation> q;
            if (round < this.queue.size()) {
                q = this.queue.get(round);
            } else {
                assert (round == this.queue.size());
                q = new TreeSet<>(VALUE_QUEUE_COMP);
                this.queue.add(q);
            }
            act.rounds.setQueued(round, true);
            q.add(act);
        }

        /**
         * Drains all rounds in ascending order for search node {@code sn}. The parent's
         * candidate activation is queued first, if there is one.
         *
         * @return the sum of the values returned by {@code Activation.process}
         */
        public double process(SearchNode sn) {
            long v = Document.this.visitedCounter++;
            if (sn.getParent() != null && sn.getParent().candidate != null) {
                this.add(sn.getParent().candidate.activation);
            }
            double delta = 0.0;
            // queue.size() is re-read each iteration: processing may open new rounds.
            for (int round = 0; round < this.queue.size(); ++round) {
                TreeSet<Activation> q = this.queue.get(round);
                while (!q.isEmpty()) {
                    Activation act = q.pollFirst();
                    act.rounds.setQueued(round, false);
                    delta += act.process(sn, round, v);
                }
            }
            return delta;
        }
    }

    /** FIFO queue of activations whose bounds need to be recomputed. */
    public class UpperBoundQueue {
        public final ArrayDeque<Activation> queue = new ArrayDeque<>();

        /** Queues the output of {@code l} unless the link's synapse is recurrent. */
        public void add(Activation.Link l) {
            if (!l.synapse.isRecurrent) {
                this.add(l.output);
            }
        }

        /** Queues {@code act} once, and only if it has no fixed input value. */
        public void add(Activation act) {
            if (!act.ubQueued && act.inputValue == null) {
                act.ubQueued = true;
                this.queue.addLast(act);
            }
        }

        /**
         * Processes bounds until the queue is empty.
         *
         * @return {@code true} if at least one activation was processed
         */
        public boolean process() {
            boolean processedAny = false;
            while (!this.queue.isEmpty()) {
                processedAny = true;
                Activation act = this.queue.pollFirst();
                act.ubQueued = false;
                act.processBounds();
            }
            return processedAny;
        }
    }

    /** Queue of lattice nodes with pending changes, ordered by level and then by enqueue order. */
    public class Queue {
        public final TreeSet<Node> queue = new TreeSet<Node>((n1, n2) -> {
            int r = Integer.compare(n1.level, n2.level);
            if (r != 0) {
                return r;
            }
            Node.ThreadState th1 = n1.getThreadState(Document.this.threadId, true);
            Node.ThreadState th2 = n2.getThreadState(Document.this.threadId, true);
            return Long.compare(th1.queueId, th2.queueId);
        });
        private long queueIdCounter = 0L;

        /** Queues {@code n} unless it is already queued for this document's thread. */
        public void add(Node n) {
            Node.ThreadState th = n.getThreadState(Document.this.threadId, true);
            if (!th.isQueued) {
                th.isQueued = true;
                // Assign the ordering key before inserting, because the comparator reads it.
                th.queueId = this.queueIdCounter++;
                this.queue.add(n);
            }
        }

        /** Polls nodes in order and lets each process its changes until the queue is empty. */
        public void processChanges() {
            while (!this.queue.isEmpty()) {
                Node n = this.queue.pollFirst();
                Node.ThreadState th = n.getThreadState(Document.this.threadId, true);
                th.isQueued = false;
                n.processChanges(Document.this);
            }
        }
    }

    /**
     * Composite key for the activation position indexes. Range queries use Node.MIN_NODE /
     * Node.MAX_NODE and Integer.MIN_VALUE / Integer.MAX_VALUE as bounds. Fields are final
     * because mutating a key inside a TreeMap would corrupt its ordering.
     */
    public static class ActKey {
        final int slot;
        final Position pos;
        final Node node;
        final int actId;

        public ActKey(int slot, Position pos, Node node, int actId) {
            this.slot = slot;
            this.pos = pos;
            this.node = node;
            this.actId = actId;
        }
    }
}

