/*
 * Decompiled with CFR 0.152.
 */
package org.aika.training;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.TreeSet;
import org.aika.Model;
import org.aika.Neuron;
import org.aika.corpus.Conflicts;
import org.aika.corpus.Document;
import org.aika.corpus.InterpretationNode;
import org.aika.lattice.Node;
import org.aika.lattice.NodeActivation;
import org.aika.lattice.OrNode;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.MetaSynapse;

/**
 * Static helpers that turn activated META neurons into concrete neurons and copy the
 * synapses marked with meta parameters ({@link MetaSynapse#metaWeight} /
 * {@link MetaSynapse#metaBias}) onto them.
 */
public class MetaNetwork {

    /**
     * Instantiates meta neurons that contributed to the final activations of {@code doc}.
     *
     * <p>For every final activation of an INHIBITORY neuron, each of its final inputs is
     * inspected. If the input comes from a META neuron, a new neuron is created. It is
     * labelled with the inhibitory neuron's label minus its first two characters, plus the
     * covered text. The meta synapses of the associated "M-" activation are then
     * transferred onto the target neuron.
     *
     * @param doc the document whose finally activated neurons are scanned; its
     *     {@code visitedCounter} is advanced and {@code finallyActivatedNeurons} may grow
     */
    public static void train(Document doc) {
        long v = doc.visitedCounter++;
        // Iterate over a snapshot: transferMetaSynapses() adds to doc.finallyActivatedNeurons.
        for (INeuron n : new ArrayList<INeuron>(doc.finallyActivatedNeurons)) {
            if (n.type != INeuron.Type.INHIBITORY) {
                continue;
            }
            for (Activation sAct : n.getFinalActivations(doc)) {
                for (Activation.SynapseActivation sa : sAct.getFinalInputActivations()) {
                    Activation act = sa.input;
                    Neuron targetNeuron = ((OrNode) act.key.node).neuron;
                    boolean newNeuron = false;
                    if (((INeuron) targetNeuron.get()).type == INeuron.Type.META) {
                        // NOTE(review): the neuron is created before the visited check below, so a
                        // META input whose meta activation was already handled yields a neuron
                        // that receives no synapses. Confirm this is intended.
                        newNeuron = true;
                        targetNeuron = doc.model.createNeuron(n.label.substring(2) + "-" + doc.getText(act.key.range));
                        INeuron.update(doc.model, doc.threadId, targetNeuron, n.bias, Collections.emptySet());
                    }
                    Activation metaNeuronAct = getMetaNeuronAct(sAct);
                    if (metaNeuronAct == null || metaNeuronAct.visited == v) {
                        continue;
                    }
                    metaNeuronAct.visited = v;
                    transferMetaSynapses(doc, metaNeuronAct, targetNeuron, newNeuron, (Neuron) n.provider, v);
                }
            }
        }
    }

    /**
     * Returns the first neuron input of {@code sAct} whose neuron label starts with "M-",
     * or {@code null} if there is none.
     */
    private static Activation getMetaNeuronAct(Activation sAct) {
        for (Activation.SynapseActivation sa : sAct.neuronInputs) {
            if (((INeuron) ((OrNode) sa.input.key.node).neuron.get()).label.startsWith("M-")) {
                return sa.input;
            }
        }
        return null;
    }

    /**
     * Copies the meta synapses of {@code metaAct} onto {@code targetNeuron} and re-evaluates
     * the target's activations.
     *
     * @param doc the document being trained
     * @param metaAct the activation of the meta neuron whose input synapses are copied
     * @param targetNeuron the neuron receiving the new synapses
     * @param newNeuron whether {@code targetNeuron} was just created; if so it also receives
     *     the meta neuron's {@code metaBias} and a synapse to {@code supprN}
     * @param supprN the inhibitory (suppressing) neuron that gets a synapse from a new target
     * @param v the visited marker of the current training pass
     */
    private static void transferMetaSynapses(Document doc, Activation metaAct, Neuron targetNeuron, boolean newNeuron, Neuron supprN, long v) {
        TreeSet<Synapse> inputSynapses = collectInputSynapses(metaAct, targetNeuron);
        double biasDelta = newNeuron ? ((INeuron) ((OrNode) metaAct.key.node).neuron.get()).metaBias : 0.0;
        INeuron.update(doc.model, doc.threadId, targetNeuron, biasDelta, inputSynapses);
        if (newNeuron) {
            connectToSuppressingNeuron(metaAct, targetNeuron, supprN);
            metaAct.key.interpretation.setState(InterpretationNode.State.EXCLUDED, v);
        }
        repropagateInputs(doc, inputSynapses);
        doc.propagate();
        selectTargetActivations(doc, targetNeuron, v);
    }

    /**
     * Builds the not-yet-existing synapses from the meta activation's inputs to
     * {@code targetNeuron}. It uses the meta synapse's {@code metaWeight} / {@code metaBias}
     * as weight and bias deltas.
     */
    private static TreeSet<Synapse> collectInputSynapses(Activation metaAct, Neuron targetNeuron) {
        TreeSet<Synapse> inputSynapses = new TreeSet<Synapse>(Synapse.INPUT_SYNAPSE_COMP);
        Integer ridOffset = computeRidOffset(metaAct);
        for (Activation.SynapseActivation sa : metaAct.getFinalInputActivations()) {
            MetaSynapse ss = sa.synapse.meta;
            if (ss == null || (ss.metaWeight == 0.0 && ss.metaBias == 0.0)) {
                continue;
            }
            Neuron ina = ((OrNode) sa.input.key.node).neuron;
            Neuron inb = null;
            Integer rid = null;
            if (((INeuron) ina.get()).type == INeuron.Type.INHIBITORY && ss.metaWeight >= 0.0) {
                // Connect to the inhibitory neuron's input rather than the inhibitory neuron itself.
                // NOTE(review): with several final inputs, the last one wins. Confirm this is intended.
                List<Activation.SynapseActivation> inputs = sa.input.getFinalInputActivations();
                for (Activation.SynapseActivation iSA : inputs) {
                    Activation iAct = iSA.input;
                    inb = ((OrNode) iAct.key.node).neuron;
                    rid = iAct.key.rid;
                }
            } else {
                inb = ina;
                rid = sa.input.key.rid;
            }
            if (inb == null) {
                continue;
            }
            Synapse.Key osk = sa.synapse.key;
            // Keep an explicit relative rid. Otherwise derive one from the activation's rid if the
            // meta synapse asks for it.
            Integer relativeRid = osk.relativeRid;
            if (relativeRid == null && ss.metaRelativeRid && ridOffset != null && rid != null) {
                relativeRid = Integer.valueOf(rid - ridOffset);
            }
            Synapse.Key nsk = new Synapse.Key(osk.isRecurrent, relativeRid, osk.absoluteRid, osk.rangeMatch, osk.rangeOutput);
            Synapse ns = new Synapse(inb, targetNeuron, nsk);
            if (ns.exists()) {
                continue;
            }
            ns.weightDelta = ss.metaWeight;
            ns.biasDelta = ss.metaBias;
            inputSynapses.add(ns);
        }
        return inputSynapses;
    }

    /**
     * Adds a synapse from {@code targetNeuron} to {@code supprN}. The synapse copies the key
     * and meta weight/bias of the meta activation's first final output link.
     */
    private static void connectToSuppressingNeuron(Activation metaAct, Neuron targetNeuron, Neuron supprN) {
        // NOTE(review): assumes the meta activation has at least one final output link;
        // get(0) throws otherwise.
        Activation.SynapseActivation inhibMetaLink = metaAct.getFinalOutputActivations().get(0);
        Synapse.Key inhibSynKey = inhibMetaLink.synapse.key;
        MetaSynapse inhibSS = inhibMetaLink.synapse.meta;
        supprN.addSynapse(new Synapse.Builder()
                .setNeuron(targetNeuron)
                .setWeight(inhibSS.metaWeight)
                .setBias(inhibSS.metaBias)
                .setRelativeRid(inhibSynKey.relativeRid)
                .setAbsoluteRid(inhibSynKey.absoluteRid)
                .setRangeMatch(inhibSynKey.rangeMatch)
                .setRangeOutput(inhibSynKey.rangeOutput));
    }

    /**
     * Resets the upper bound of every final activation of each new synapse's input neuron.
     * It then pushes those activations through the lattice again so the new synapses see them.
     */
    private static void repropagateInputs(Document doc, TreeSet<Synapse> inputSynapses) {
        for (Synapse s : inputSynapses) {
            for (Activation iAct : ((INeuron) s.input.get()).getFinalActivations(doc)) {
                iAct.upperBound = 0.0;
                repropagate(doc, iAct);
            }
        }
    }

    /**
     * Selects each non-conflicting activation of {@code targetNeuron} together with its
     * inhibitory output and, if present, the META activation above it. If the target ends up
     * inactive, it is excluded again and the META activation is re-selected. Any of these
     * activations that become final have their neurons recorded in
     * {@code doc.finallyActivatedNeurons}.
     */
    private static void selectTargetActivations(Document doc, Neuron targetNeuron, long v) {
        for (Activation tAct : ((INeuron) targetNeuron.get()).getAllActivations(doc)) {
            if (isConflicting(tAct, doc.visitedCounter++)) {
                continue;
            }
            Activation sAct = getOutputAct(tAct.neuronOutputs, INeuron.Type.INHIBITORY);
            if (sAct == null) {
                // No inhibitory output to select alongside the target; previously an NPE.
                continue;
            }
            tAct.key.interpretation.setState(InterpretationNode.State.SELECTED, v);
            sAct.key.interpretation.setState(InterpretationNode.State.SELECTED, v);
            Activation mAct = getOutputAct(sAct.neuronOutputs, INeuron.Type.META);
            List<Activation> newActs = new ArrayList<Activation>();
            if (mAct != null) {
                mAct.visited = v;
                newActs.add(mAct);
            }
            newActs.add(tAct);
            newActs.add(sAct);
            for (Activation act : newActs) {
                doc.vQueue.add(0, act);
            }
            doc.vQueue.processChanges(doc.selectedSearchNode, doc.visitedCounter++, doc.selectedSearchNode.visited);
            if (tAct.getFinalState().value <= 0.0) {
                tAct.key.interpretation.setState(InterpretationNode.State.EXCLUDED, v);
                if (mAct != null) {
                    mAct.key.interpretation.setState(InterpretationNode.State.SELECTED, doc.selectedSearchNode.visited);
                }
                for (Activation act : newActs) {
                    doc.vQueue.add(0, act);
                }
                doc.vQueue.processChanges(doc.selectedSearchNode, doc.visitedCounter++, doc.selectedSearchNode.visited);
            }
            for (Activation act : newActs) {
                if (act.isFinalActivation()) {
                    doc.finallyActivatedNeurons.add((INeuron) ((OrNode) act.key.node).neuron.get(doc));
                }
            }
        }
    }

    /**
     * Returns {@code true} if any interpretation conflicting with {@code tAct}'s
     * interpretation is currently SELECTED.
     *
     * @param v visited marker passed through to {@link Conflicts#collectConflicting}
     */
    private static boolean isConflicting(Activation tAct, long v) {
        ArrayList<InterpretationNode> conflicting = new ArrayList<InterpretationNode>();
        Conflicts.collectConflicting(conflicting, tAct.key.interpretation, v);
        for (InterpretationNode c : conflicting) {
            if (c.state == InterpretationNode.State.SELECTED) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the output activation of the first link whose neuron has the given type, or
     * {@code null} if none matches.
     */
    private static Activation getOutputAct(TreeSet<Activation.SynapseActivation> outputActs, INeuron.Type type) {
        for (Activation.SynapseActivation sa : outputActs) {
            if (((INeuron) ((OrNode) sa.output.key.node).neuron.get()).type == type) {
                return sa.output;
            }
        }
        return null;
    }

    /**
     * Re-runs {@code propagateAddedActivation} for {@code act} on its node. It then recurses
     * into outputs that are plain lattice activations; neuron-level {@link Activation}
     * outputs are not followed.
     */
    private static void repropagate(Document doc, NodeActivation<?> act) {
        ((Node) act.key.node).propagateAddedActivation(doc, act);
        for (NodeActivation<?> oAct : act.outputs.values()) {
            if (!(oAct instanceof Activation)) {
                repropagate(doc, oAct);
            }
        }
    }

    /**
     * Derives the rid offset of {@code mAct} from the first final input link where both the
     * synapse's relative rid and the input's rid are set: {@code inputRid - relativeRid}.
     *
     * @return the offset, or {@code null} if no such input exists
     */
    private static Integer computeRidOffset(Activation mAct) {
        for (Activation.SynapseActivation sa : mAct.getFinalInputActivations()) {
            if (sa.synapse.key.relativeRid != null && sa.input.key.rid != null) {
                return sa.input.key.rid - sa.synapse.key.relativeRid;
            }
        }
        return null;
    }

    /**
     * Initializes {@code n} as a META neuron.
     *
     * @param m the model (not used by this method)
     * @param n the neuron to initialize
     * @param bias the neuron's bias
     * @param metaBias bias given to neurons instantiated from this meta neuron
     * @param inputs the input synapses
     * @return the result of {@link Neuron#init}
     */
    public static Neuron initMetaNeuron(Model m, Neuron n, double bias, double metaBias, Synapse.Builder ... inputs) {
        ((INeuron) n.get()).metaBias = metaBias;
        return Neuron.init(n, bias, INeuron.Type.META, inputs);
    }
}

