/*
 * Decompiled with CFR 0.152.
 */
package org.aika.corpus;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.aika.AbstractNode;
import org.aika.Converter;
import org.aika.Model;
import org.aika.Provider;
import org.aika.Utils;
import org.aika.corpus.Candidate;
import org.aika.corpus.InterpretationNode;
import org.aika.corpus.Range;
import org.aika.corpus.SearchNode;
import org.aika.lattice.InputNode;
import org.aika.lattice.Node;
import org.aika.lattice.NodeActivation;
import org.aika.lattice.OrNode;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.SupervisedTraining;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Per-document processing state.
 *
 * <p>A {@code Document} wraps one input text together with everything created while a
 * {@link Model} processes it on one thread ({@link #threadId}): the activated neurons and nodes,
 * three work queues ({@link Queue} for lattice nodes, {@link ValueQueue} for round-based
 * value/weight recomputation, {@link UpperBoundQueue} for bound recomputation), the
 * interpretation search state, and the weight changes pending {@link #commit()}.
 *
 * <p>Documents are ordered by {@link #id}.
 */
public class Document implements Comparable<Document> {
    private static final Logger log = LoggerFactory.getLogger(Document.class);

    /** Log every node processed by {@link Queue#processChanges()} together with all activations. */
    public static boolean APPLY_DEBUG_OUTPUT = false;
    /** Log search and value-propagation details in {@link #process()} and {@link ValueQueue}. */
    public static boolean OPTIMIZE_DEBUG_OUTPUT = false;
    /**
     * Distance in document ids used by {@link #clearActivations()}: per-thread node state is
     * cleaned up at most once per this many ids, and a node's thread state is dropped once it has
     * been unused for more than this many ids.
     */
    public static int CLEANUP_INTERVAL = 500;
    /**
     * Highest round from which {@link ValueQueue#processChanges} may still propagate a changed
     * activation; going beyond it aborts with an exception (the network may be oscillating).
     */
    public static int MAX_ROUND = 20;

    /** Document id; determines the ordering of documents. */
    public final int id;
    private final String content;

    /** Source of fresh "visited" markers for graph traversals; incremented on every use. */
    public long visitedCounter = 1L;
    // ID counters. NOTE(review): presumably advanced by InterpretationNode, Activation and
    // SearchNode code outside this file — confirm.
    public int interpretationIdCounter = 1;
    public int activationIdCounter = 0;
    public int searchNodeIdCounter = 0;
    public int searchStepCounter = 0;

    // NOTE: field initializers below run in declaration order and several receive `this`;
    // keep the order stable.

    /** Root interpretation node; its UNKNOWN children become search candidates. */
    public InterpretationNode bottom = new InterpretationNode(this, -1, 0, 0);
    public Model model;
    public int threadId;

    /** Pending lattice nodes, processed by node level, then by insertion order. */
    public Queue queue = new Queue();
    /** Round-based queue for recomputing activation values and weights. */
    public ValueQueue vQueue = new ValueQueue();
    /** FIFO queue for recomputing activation bounds. */
    public UpperBoundQueue ubQueue = new UpperBoundQueue();

    /** Neurons activated in this document; cleared by {@link #clearActivations()}. */
    public TreeSet<INeuron> activatedNeurons = new TreeSet<>();
    /** Neurons whose final activations are reported by {@link #getFinalActivations()}. */
    public TreeSet<INeuron> finallyActivatedNeurons = new TreeSet<>();
    /** Input activations whose outputs are queued at the start of {@link #process()}. */
    public TreeSet<Activation> inputNeuronActivations = new TreeSet<>();
    /** Synapses with modified weights, per neuron; flushed by {@link #commit()}. */
    public TreeMap<INeuron, Set<Synapse>> modifiedWeights = new TreeMap<>();
    public SupervisedTraining supervisedTraining = new SupervisedTraining(this);
    /** Activations ordered by the key's rid first, then by the key's natural order. */
    public TreeMap<NodeActivation.Key, Activation> activationsByRid =
            new TreeMap<NodeActivation.Key, Activation>((k1, k2) -> {
                int r = Integer.compare(k1.rid, k2.rid);
                if (r != 0) {
                    return r;
                }
                return k1.compareTo((NodeActivation.Key) k2);
            });
    /** Lattice nodes added for this document; cleared by {@link #clearActivations()}. */
    public TreeSet<Node> addedNodes = new TreeSet<>();
    public SearchNode rootSearchNode = new SearchNode(this, null, null, null, -1, Collections.emptySet(), false);
    /** Search node chosen by the search; read by {@link #process()} (may stay {@code null}). */
    public SearchNode selectedSearchNode = null;
    /** Interpretations pre-selected before the search; built by {@code expandRootRefinement}. */
    public List<InterpretationNode> rootRefs;
    /** Search candidates in dependency order; built by {@link #generateCandidates()}. */
    public ArrayList<Candidate> candidates;
    /** Result of {@link #process()}: {@link #bottom} followed by the selected interpretations. */
    public List<InterpretationNode> bestInterpretation = null;

    /** Orders activations for output: by range, then rid, then interpretation, then node. */
    public static Comparator<NodeActivation> ACTIVATIONS_OUTPUT_COMPARATOR = (act1, act2) -> {
        int r = Range.compare(act1.key.range, act2.key.range, false);
        if (r != 0) {
            return r;
        }
        r = Utils.compareInteger(act1.key.rid, act2.key.rid);
        if (r != 0) {
            return r;
        }
        r = act1.key.interpretation.compareTo(act2.key.interpretation);
        if (r != 0) {
            return r;
        }
        return ((Node) act1.key.node).compareTo((Node) act2.key.node);
    };

    /** Orders activations within one value-queue round: by sequence, then by activation id. */
    private static final Comparator<Activation> VALUE_QUEUE_COMP = (a, b) -> {
        int r = Integer.compare(a.getSequence(), b.getSequence());
        return r != 0 ? r : Integer.compare(a.id, b.id);
    };

    /**
     * @param id       document id (also the document's position in the ordering)
     * @param content  the text being processed
     * @param model    the model this document is processed with
     * @param threadId index of the processing thread in the model's per-thread arrays
     */
    public Document(int id, String content, Model model, int threadId) {
        this.id = id;
        this.content = content;
        this.model = model;
        this.threadId = threadId;
    }

    public String getContent() {
        return content;
    }

    public int length() {
        return content.length();
    }

    @Override
    public String toString() {
        return content;
    }

    /**
     * Returns the text covered by {@code r}. Both ends are clamped to {@code [0, length()]} and
     * the end is never placed before the begin, so out-of-range or inverted ranges yield a
     * (possibly empty) substring instead of throwing.
     */
    public String getText(Range r) {
        int len = length();
        int begin = Math.max(0, Math.min(r.begin, len));
        int end = Math.max(begin, Math.min(r.end, len));
        return content.substring(begin, end);
    }

    /** Streams the final activations of every neuron in {@link #finallyActivatedNeurons}. */
    public Stream<Activation> getFinalActivations() {
        return finallyActivatedNeurons.stream().flatMap(in -> in.getFinalActivations(this).stream());
    }

    /** Formats {@link #bestInterpretation}; prints {@code null} if {@link #process()} has not run. */
    public String bestInterpretationToString() {
        return "Best Interpretation:\n" + bestInterpretation + "\n";
    }

    @Override
    public int compareTo(Document doc) {
        return Integer.compare(id, doc.id);
    }

    /**
     * Drains the node queue, then the upper-bound queue, and repeats until a pass of the
     * upper-bound queue finds nothing to process.
     */
    public void propagate() {
        do {
            queue.processChanges();
        } while (ubQueue.process());
    }

    /**
     * Builds {@link #rootRefs}: {@link #bottom} plus every child of it that is already SELECTED,
     * or that is primitive and has neither primary nor secondary conflicts.
     */
    private void expandRootRefinement() {
        rootRefs = new ArrayList<InterpretationNode>();
        rootRefs.add(bottom);
        for (InterpretationNode pn : bottom.children) {
            if (pn.state == InterpretationNode.State.SELECTED
                    || (pn.isPrimitive() && pn.conflicts.primary.isEmpty() && pn.conflicts.secondary.isEmpty())) {
                rootRefs.add(pn);
            }
        }
    }

    /**
     * Builds {@link #candidates} from the UNKNOWN children of {@link #bottom}.
     *
     * <p>The activations of all {@link #rootRefs} are marked first. Then, repeatedly, the first
     * remaining candidate (in {@link Candidate} order) for which
     * {@code checkDependenciesSatisfied} holds is appended, numbered and marked. If no remaining
     * candidate qualifies, the lowest-ordered one is forced so that the loop always terminates.
     */
    public void generateCandidates() {
        TreeSet<Candidate> pending = new TreeSet<Candidate>();
        int seq = 0;
        for (InterpretationNode cn : bottom.children) {
            if (cn.state == InterpretationNode.State.UNKNOWN) {
                pending.add(new Candidate(cn, seq++));
            }
        }

        long v = visitedCounter++;
        for (InterpretationNode n : rootRefs) {
            markCandidateSelected(n, v);
        }

        int nextId = 0;
        candidates = new ArrayList<>();
        while (!pending.isEmpty()) {
            Candidate next = null;
            for (Candidate c : pending) {
                if (c.checkDependenciesSatisfied(v)) {
                    next = c;
                    break;
                }
            }
            if (next == null) {
                // No remaining candidate has its dependencies satisfied; without this fallback
                // the surrounding while-loop would never terminate.
                next = pending.first();
                log.warn("Unable to satisfy the dependencies of any remaining candidate, forcing: {}", next);
            }
            // Remove before re-numbering so the removal uses the candidate's ordering as inserted.
            pending.remove(next);
            next.id = nextId++;
            candidates.add(next);
            markCandidateSelected(next.refinement, v);
        }
    }

    /** Stamps every neuron activation of {@code n} (if it has any) with the visited marker {@code v}. */
    private static void markCandidateSelected(InterpretationNode n, long v) {
        if (n.neuronActivations == null) {
            return;
        }
        for (Activation act : n.neuronActivations) {
            act.visited = v;
        }
    }

    /**
     * Runs the interpretation search for this document and stores the outcome in
     * {@link #bestInterpretation} ({@link #bottom} followed by the results collected from
     * {@link #selectedSearchNode}, if the search selected one).
     */
    public void process() {
        // Queue the direct outputs of every input activation (round 0, or 1 for recurrent synapses).
        inputNeuronActivations.forEach(act -> vQueue.propagateWeight(0, act));
        expandRootRefinement();
        if (OPTIMIZE_DEBUG_OUTPUT) {
            log.info("Root SearchNode:" + this.toString());
        }
        rootRefs.forEach(n -> n.setState(InterpretationNode.State.SELECTED, rootSearchNode.visited));
        generateCandidates();

        Candidate first = candidates.isEmpty() ? null : candidates.get(0);
        SearchNode child = new SearchNode(this, rootSearchNode, null, first, 0, rootRefs, false);
        SearchNode.search(this, child);

        ArrayList<InterpretationNode> results = new ArrayList<InterpretationNode>();
        results.add(bottom);
        if (selectedSearchNode != null) {
            selectedSearchNode.reconstructSelectedResult(this);
            selectedSearchNode.collectResults(results);
        }
        bestInterpretation = results;
        if (OPTIMIZE_DEBUG_OUTPUT) {
            dumpDebugCandidateStatistics();
        }
    }

    /** Logs every candidate; does nothing before {@link #generateCandidates()} has run. */
    public void dumpDebugCandidateStatistics() {
        if (candidates == null) {
            return;
        }
        for (Candidate c : candidates) {
            log.info(c.toString());
        }
    }

    /** Records that the given input synapses of {@code n} changed; flushed by {@link #commit()}. */
    public void notifyWeightsModified(INeuron n, Collection<Synapse> inputSynapses) {
        modifiedWeights
                .computeIfAbsent(n, k -> new TreeSet<Synapse>(Synapse.INPUT_SYNAPSE_COMP))
                .addAll(inputSynapses);
    }

    /** Passes every recorded weight modification to {@link Converter#convert}. */
    public void commit() {
        modifiedWeights.forEach((n, inputSyns) -> Converter.convert(model, threadId, n, inputSyns));
    }

    /**
     * Collects every neuron activation of the activated neurons plus all node activations
     * reachable from them through {@code outputs}.
     */
    public Collection<NodeActivation> getAllNodeActivations() {
        long v = visitedCounter++;
        TreeSet<NodeActivation> results = new TreeSet<NodeActivation>();
        for (INeuron n : activatedNeurons) {
            for (Activation act : n.getAllActivations(this)) {
                collectNodeActivations(results, act, v);
            }
        }
        return results;
    }

    /** Depth-first walk over {@code outputs}; {@code v} prevents visiting an activation twice. */
    private void collectNodeActivations(Collection<NodeActivation> results, NodeActivation<?> act, long v) {
        if (act.visited == v) {
            return;
        }
        act.visited = v;
        results.add(act);
        for (NodeActivation<?> oAct : act.outputs.values()) {
            collectNodeActivations(results, oAct, v);
        }
    }

    /**
     * Clears the activations of all activated neurons and the added nodes, and releases this
     * document's slot in {@code model.docs}. At most once every {@link #CLEANUP_INTERVAL}
     * document ids per thread, it also drops per-thread node state that has not been used for
     * more than {@link #CLEANUP_INTERVAL} ids.
     */
    public void clearActivations() {
        activatedNeurons.forEach(n -> n.clearActivations(this));
        activatedNeurons.clear();
        addedNodes.clear();

        if (model.lastCleanup[threadId] + CLEANUP_INTERVAL < id) {
            model.lastCleanup[threadId] = id;

            // Copy the providers while holding the map's monitor.
            // NOTE(review): presumably the same lock writers of activeProviders use — confirm.
            Map<Integer, Provider<? extends AbstractNode>> map = model.activeProviders;
            ArrayList<Provider<? extends AbstractNode>> providers;
            synchronized (map) {
                providers = new ArrayList<Provider<? extends AbstractNode>>(map.values());
            }

            for (Provider<? extends AbstractNode> np : providers) {
                Object an = np.getIfNotSuspended();
                if (!(an instanceof Node)) {
                    continue;
                }
                Node n = (Node) an;
                Node.ThreadState th = n.threads[threadId];
                if (th != null && th.lastUsed + (long) CLEANUP_INTERVAL < (long) id) {
                    n.threads[threadId] = null;
                }
            }
        }
        model.docs[threadId] = null;
    }

    /**
     * Writes the {@code outputText} of every finally activated neuron that has one over the
     * range of each of that neuron's final activations, and returns the result.
     *
     * <p>NOTE(review): the builder starts out empty, so {@link StringBuilder#replace} throws
     * {@link StringIndexOutOfBoundsException} for any range that begins beyond the text produced
     * so far, and replacements of different length shift later offsets. Confirm whether the
     * builder was meant to be seeded with {@link #getContent()}.
     */
    public String generateOutputText() {
        StringBuilder sb = new StringBuilder();
        finallyActivatedNeurons.stream().filter(n -> n.outputText != null).forEach(n -> {
            for (Activation act : n.getFinalActivations(this)) {
                sb.replace(act.key.range.begin, act.key.range.end, n.outputText);
            }
        });
        return sb.toString();
    }

    public String neuronActivationsToString(boolean withWeights) {
        return neuronActivationsToString(withWeights, false, false);
    }

    public String neuronActivationsToString(boolean withWeights, boolean withTextSnipped, boolean withLogic) {
        return neuronActivationsToString(null, withWeights, withTextSnipped, withLogic);
    }

    /**
     * Renders one line per activation of the activated neurons (ordered by
     * {@link #ACTIVATIONS_OUTPUT_COMPARATOR}), skipping activations whose upper bound is
     * non-positive unless they have a positive target value. Appends a summary of
     * {@link #selectedSearchNode} when one exists.
     */
    public String neuronActivationsToString(SearchNode sn, boolean withWeights, boolean withTextSnipped, boolean withLogic) {
        // Element type Activation (not NodeActivation) is required by the loop below; the
        // NodeActivation comparator still applies because Activation is a NodeActivation.
        TreeSet<Activation> acts = new TreeSet<Activation>(ACTIVATIONS_OUTPUT_COMPARATOR);
        for (INeuron n : activatedNeurons) {
            Stream<Activation> stream = Activation.select(this, n, null, null, null, null, InterpretationNode.Relation.CONTAINED_IN);
            acts.addAll(stream.collect(Collectors.toList()));
        }

        StringBuilder sb = new StringBuilder();
        for (Activation activation : acts) {
            if (activation.upperBound <= 0.0 && (activation.targetValue == null || activation.targetValue <= 0.0)) {
                continue;
            }
            sb.append(activation.toString(sn, withWeights, withTextSnipped, withLogic));
            sb.append("\n");
        }
        if (selectedSearchNode != null) {
            sb.append("\n Final SearchNode:" + selectedSearchNode.id + "  WeightSum:" + selectedSearchNode.accumulatedWeight.toString() + "\n");
        }
        return sb.toString();
    }

    /** Logs every activation whose last round is greater than {@code MAX_ROUND - 5}. */
    private void dumpOscillatingActivations() {
        activatedNeurons.stream()
                .flatMap(n -> n.getAllActivations(this).stream())
                .filter(act -> act.rounds.getLastRound() != null && act.rounds.getLastRound() > MAX_ROUND - 5)
                .forEach(act -> log.error(act.key + " " + act.key.interpretation.state + " " + act.rounds));
    }

    /**
     * Round-based work queue for recomputing activation values and weights. Each round has its
     * own set ordered by {@code VALUE_QUEUE_COMP}; rounds are processed in ascending order.
     */
    public class ValueQueue {
        public final ArrayList<TreeSet<Activation>> queue = new ArrayList<>();

        /**
         * Enqueues the outputs of {@code act}: outputs reached over recurrent synapses go into
         * round {@code round + 1}, all others into {@code round}.
         */
        public void propagateWeight(int round, Activation act) {
            for (Activation.SynapseActivation sa : act.neuronOutputs) {
                int r = sa.synapse.key.isRecurrent ? round + 1 : round;
                add(r, sa.output);
            }
        }

        /**
         * Re-evaluates the neuron activations of the {@code changed} interpretation nodes
         * (plus their recurrent outputs) starting at round 0.
         *
         * @return the accumulated weight delta, see {@link #processChanges}
         */
        public INeuron.NormWeight adjustWeight(SearchNode cand, Collection<InterpretationNode> changed, long visitedModified) {
            long v = Document.this.visitedCounter++;
            for (InterpretationNode n : changed) {
                addAllActs(n.getNeuronActivations());
            }
            return processChanges(cand, v, visitedModified);
        }

        /** Enqueues each activation and its recurrent outputs into round 0. */
        private void addAllActs(Collection<Activation> acts) {
            for (Activation act : acts) {
                add(0, act);
                for (Activation.SynapseActivation sa : act.neuronOutputs) {
                    if (sa.synapse.key.isRecurrent) {
                        add(0, sa.output);
                    }
                }
            }
        }

        /**
         * Enqueues {@code act} for {@code round}, unless it is already queued there or its
         * interpretation is still UNKNOWN.
         */
        public void add(int round, Activation act) {
            if (act.rounds.isQueued(round) || act.key.interpretation.state == InterpretationNode.State.UNKNOWN) {
                return;
            }
            // Create every missing round up to `round`. A round can be requested before the
            // preceding one exists (e.g. a recurrent output of an input activation goes to
            // round 1 while the queue is still empty); appending a single set would file the
            // activation under the wrong round.
            while (queue.size() <= round) {
                queue.add(new TreeSet<>(VALUE_QUEUE_COMP));
            }
            act.rounds.setQueued(round, true);
            queue.get(round).add(act);
        }

        /**
         * Processes all queued rounds in ascending order. For every activation, it computes the
         * new state for the round and stores it. If the state changed, the activation's outputs
         * are propagated. Every activation handled in round 0 is also re-queued for round 1.
         *
         * @param sn              search node whose {@code modifiedActs} receive the saved old states
         * @param v               visited marker passed to {@code saveOldState}
         * @param visitedModified value stored in {@code rounds.modified} of every updated activation
         * @return the sum of (new weight - old weight) over the updates made at or beyond each
         *         activation's last recorded round (a missing old state counts as zero weight)
         * @throws RuntimeException if a change would propagate from a round beyond {@link #MAX_ROUND}
         */
        public INeuron.NormWeight processChanges(SearchNode sn, long v, long visitedModified) {
            INeuron.NormWeight delta = INeuron.NormWeight.ZERO_WEIGHT;
            // queue.size() is re-read on every iteration: processing a round may open the next one.
            for (int round = 0; round < queue.size(); ++round) {
                TreeSet<Activation> q = queue.get(round);
                while (!q.isEmpty()) {
                    Activation act = q.pollFirst();
                    act.rounds.setQueued(round, false);

                    Activation.State s = act.inputValue != null
                            ? new Activation.State(act.inputValue, 0, INeuron.NormWeight.ZERO_WEIGHT)
                            : ((INeuron) ((OrNode) act.key.node).neuron.get(Document.this)).computeWeight(round, act);
                    if (OPTIMIZE_DEBUG_OUTPUT) {
                        log.info(act.key + " Round:" + round);
                        log.info("Value:" + s.value + "  Weight:" + s.weight.w + "  Norm:" + s.weight.n + "\n");
                    }

                    if (round != 0) {
                        Activation.State current = act.rounds.get(round);
                        // An unchanged state in a later round needs neither an update nor propagation.
                        if (current != null && current.equalsWithWeights(s)) {
                            continue;
                        }
                    }

                    act.saveOldState(sn.modifiedActs, v);
                    Activation.State oldState = act.rounds.get(round);
                    boolean propagate = act.rounds.set(round, s) && (oldState == null || !oldState.equals(s));
                    act.rounds.modified = visitedModified;
                    act.saveNewState();

                    if (propagate) {
                        if (round > MAX_ROUND) {
                            log.error("Error: Maximum number of rounds reached. The network might be oscillating.");
                            log.info(Document.this.neuronActivationsToString(sn, true, true, true));
                            Document.this.dumpOscillatingActivations();
                            throw new RuntimeException("Maximum number of rounds reached. The network might be oscillating.");
                        }
                        propagateWeight(round, act);
                    }
                    if (round == 0) {
                        add(1, act);
                    }

                    // Only updates at (or beyond) the activation's last recorded round count towards the delta.
                    if (act.rounds.getLastRound() == null || round < act.rounds.getLastRound()) {
                        continue;
                    }
                    INeuron.NormWeight oldWeight = oldState != null ? oldState.weight : INeuron.NormWeight.ZERO_WEIGHT;
                    delta = delta.add(s.weight.sub(oldWeight));
                }
            }
            return delta;
        }
    }

    /** FIFO work queue for recomputing activation bounds; each activation is queued at most once. */
    public class UpperBoundQueue {
        public final ArrayDeque<Activation> queue = new ArrayDeque<>();

        /** Enqueues {@code act} unless it is already queued. */
        public void add(Activation act) {
            if (act.ubQueued) {
                return;
            }
            act.ubQueued = true;
            queue.addLast(act);
        }

        /**
         * Drains the queue. Each activation gets its bounds recomputed by its neuron, or set to
         * its input value. Its outputs are re-queued when the upper bound moved by more than 0.01.
         * An activation whose upper bound rose from {@code <= 0} to {@code > 0} is handed to the
         * input nodes attached to its neuron's output nodes.
         *
         * @return {@code true} if at least one activation was processed
         */
        public boolean process() {
            boolean processedAny = false;
            while (!queue.isEmpty()) {
                processedAny = true;
                Activation act = queue.pollFirst();
                act.ubQueued = false;

                double oldUpperBound = act.upperBound;
                INeuron n = (INeuron) ((OrNode) act.key.node).neuron.get(Document.this);
                if (act.inputValue == null) {
                    n.computeBounds(act);
                } else {
                    act.upperBound = act.inputValue;
                    act.lowerBound = act.inputValue;
                }

                if (Math.abs(act.upperBound - oldUpperBound) > 0.01) {
                    for (Activation.SynapseActivation sa : act.neuronOutputs) {
                        add(sa.output);
                    }
                }

                if (oldUpperBound <= 0.0 && act.upperBound > 0.0) {
                    for (Provider provider : n.outputNodes.values()) {
                        ((InputNode) provider.get(Document.this)).addActivation(Document.this, act);
                    }
                }
            }
            return processedAny;
        }
    }

    /**
     * Work queue of lattice nodes with pending changes, ordered by node level and then by the
     * order in which the nodes were queued (per-thread {@code queueId}).
     */
    public class Queue {
        public final TreeSet<Node> queue = new TreeSet<Node>((n1, n2) -> {
            int r = Integer.compare(n1.level, n2.level);
            if (r != 0) {
                return r;
            }
            Node.ThreadState th1 = n1.getThreadState(Document.this.threadId, true);
            Node.ThreadState th2 = n2.getThreadState(Document.this.threadId, true);
            return Long.compare(th1.queueId, th2.queueId);
        });
        private long queueIdCounter = 0L;

        /** Enqueues {@code n} unless it is already queued for this document's thread. */
        public void add(Node n) {
            Node.ThreadState th = n.getThreadState(Document.this.threadId, true);
            if (th.isQueued) {
                return;
            }
            th.isQueued = true;
            // queueId must be assigned before insertion: the comparator reads it.
            th.queueId = queueIdCounter++;
            queue.add(n);
        }

        /** Drains the queue, letting each node process its changes for this document. */
        public void processChanges() {
            while (!queue.isEmpty()) {
                Node n = queue.pollFirst();
                Node.ThreadState th = n.getThreadState(Document.this.threadId, true);
                th.isQueued = false;
                n.processChanges(Document.this);

                if (APPLY_DEBUG_OUTPUT) {
                    log.info("QueueId:" + th.queueId);
                    log.info(n.toString() + "\n");
                    log.info("\n" + Document.this.neuronActivationsToString(true, true, true));
                }
            }
        }
    }
}

