/*
 * Decompiled with CFR 0.152.
 */
package org.aika.training;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.aika.Document;
import org.aika.Utils;
import org.aika.lattice.OrNode;
import org.aika.neuron.INeuron;
import org.aika.neuron.Neuron;
import org.aika.neuron.Synapse;
import org.aika.neuron.activation.Activation;
import org.aika.training.MetaSynapse;
import org.aika.training.NeuronStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static training helpers that work on the final activations of a {@link Document}:
 * for every finally-activated INHIBITORY neuron, the inputs of its final activations are
 * collected as targets, grouped by the META activation feeding the inhibitory activation,
 * and the meta synapses of that META activation are then transferred onto each target.
 *
 * <p>All methods are static; the class holds no state besides its logger.
 */
public class MetaNetwork {
    private static final Logger log = LoggerFactory.getLogger(MetaNetwork.class);

    /**
     * Runs one training pass over {@code doc}.
     *
     * <p>Phase 1 walks the final input activations of every final activation of every
     * finally-activated INHIBITORY neuron. If the input's neuron is of type META, a new
     * neuron is created (labelled from the inhibitory neuron's label and the covered text)
     * and used as the target instead. Targets are grouped by the META activation found via
     * {@link #getMetaNeuronAct(Activation)}; inputs whose inhibitory activation has no META
     * input are not recorded.
     *
     * <p>Phase 2 calls {@link #transferMetaSynapses} for every recorded target.
     *
     * @param doc the document whose final activations are processed; its
     *            {@code visitedCounter} is incremented as a side effect
     */
    public static void train(Document doc) {
        long v = doc.visitedCounter++;
        Map<Activation, List<Target>> metaActivations = new TreeMap<>();
        List<INeuron> inhibitoryNeurons = doc.finallyActivatedNeurons.stream()
                .filter(n -> n.type == INeuron.Type.INHIBITORY)
                .collect(Collectors.toList());

        for (INeuron inhibNeuron : inhibitoryNeurons) {
            for (Activation inhibAct : inhibNeuron.getFinalActivations(doc)) {
                for (Activation.SynapseActivation sa : inhibAct.getFinalInputActivations()) {
                    Activation act = sa.input;
                    Neuron targetNeuron = act.getNeuron();
                    ++doc.visitedCounter;
                    // NOTE(review): self-assignment, most likely a decompiler artefact. The
                    // intended right-hand side cannot be determined from this file; kept
                    // unchanged so behavior is preserved. Confirm against the original source.
                    doc.createV = doc.createV;

                    boolean newNeuron = false;
                    if (((INeuron)targetNeuron.get()).type == INeuron.Type.META) {
                        // The input is itself a META neuron: create a fresh neuron as target.
                        newNeuron = true;
                        targetNeuron = doc.model.createNeuron(inhibNeuron.label.substring(2) + "-" + doc.getText(act.key.range));
                        INeuron.update(doc.model, doc.threadId, doc, targetNeuron, inhibNeuron.bias, Collections.emptySet());
                    }

                    // Looked up after the neuron creation above, exactly as before, so the
                    // order of side effects on the model is unchanged.
                    Activation metaNeuronAct = getMetaNeuronAct(inhibAct);
                    if (metaNeuronAct == null) {
                        continue;
                    }
                    metaActivations
                            .computeIfAbsent(metaNeuronAct, k -> new ArrayList<>())
                            .add(new Target(targetNeuron, newNeuron, (Neuron)inhibNeuron.provider));
                }
            }
        }

        for (Map.Entry<Activation, List<Target>> entry : metaActivations.entrySet()) {
            for (Target t : entry.getValue()) {
                transferMetaSynapses(doc, metaActivations, entry.getKey(), t, v);
            }
        }
    }

    /**
     * Returns the first entry of {@code inhibAct.neuronInputs} whose input neuron is of
     * type META.
     *
     * @param inhibAct the activation whose neuron inputs are scanned
     * @return the matching input activation, or {@code null} if there is none
     */
    private static Activation getMetaNeuronAct(Activation inhibAct) {
        for (Activation.SynapseActivation sa : inhibAct.neuronInputs) {
            if (sa.input.getINeuron().type == INeuron.Type.META) {
                return sa.input;
            }
        }
        return null;
    }

    /**
     * Builds the input synapses for target {@code t} from the meta synapses on the final
     * inputs of {@code metaAct}, applies them via {@code INeuron.update}, and, for newly
     * created targets, adds a synapse from the target to its inhibitory neuron. Ends by
     * calling {@code doc.propagate()}.
     *
     * @param doc             the document being trained
     * @param metaActivations targets grouped by META activation; read only, used to resolve
     *                        inputs that are themselves META neurons
     * @param metaAct         the META activation whose input synapses are transferred
     * @param t               the target receiving the synapses
     * @param v               unused in this method; kept so the call shape stays the same
     */
    private static void transferMetaSynapses(Document doc, Map<Activation, List<Target>> metaActivations, Activation metaAct, Target t, long v) {
        TreeSet<Synapse> inputSynapses = new TreeSet<>(Synapse.INPUT_SYNAPSE_COMP);
        Integer ridOffset = computeRidOffset(metaAct);

        for (Activation.SynapseActivation sa : metaAct.getFinalInputActivations()) {
            MetaSynapse inputMetaSynapse = sa.synapse.meta;
            Synapse.Key osk = sa.synapse.key;
            // Skip synapses without meta information or with nothing to transfer.
            if (inputMetaSynapse == null
                    || (inputMetaSynapse.metaWeight == 0.0 && inputMetaSynapse.metaBias == 0.0)) {
                continue;
            }

            Neuron ina = ((OrNode)sa.input.key.node).neuron;
            // For an INHIBITORY input with non-negative meta weight, use that activation's
            // own final inputs instead of the inhibitory activation itself.
            boolean expandInhibitory = ((INeuron)ina.get()).type == INeuron.Type.INHIBITORY
                    && inputMetaSynapse.metaWeight >= 0.0;
            List<Activation.SynapseActivation> inputs = expandInhibitory
                    ? sa.input.getFinalInputActivations()
                    : Collections.singletonList(sa);

            for (Activation.SynapseActivation isa : inputs) {
                Neuron in = isa.input.getNeuron();
                Integer rid = isa.input.key.rid;
                Integer nRid = Utils.nullSafeSub(rid, false, ridOffset, false);

                if (((INeuron)in.get(doc)).type == INeuron.Type.META) {
                    // A META input is replaced by the targets recorded for its activation.
                    List<Target> inputTargets = metaActivations.get(isa.input);
                    if (inputTargets == null) {
                        continue;
                    }
                    for (Target it : inputTargets) {
                        createOrLookupSynapse(doc, t, inputSynapses, inputMetaSynapse, osk, nRid, it.targetNeuron);
                    }
                    continue;
                }
                createOrLookupSynapse(doc, t, inputSynapses, inputMetaSynapse, osk, nRid, in);
            }
        }

        if (log.isDebugEnabled()) {
            log.debug(showDelta((INeuron)t.targetNeuron.get(), inputSynapses));
        }

        // Only a newly created neuron receives the meta bias; existing ones get a 0.0 delta.
        INeuron.update(doc.model, doc.threadId, doc, t.targetNeuron, t.isNewNeuron ? metaAct.getINeuron().metaBias : 0.0, inputSynapses);

        if (t.isNewNeuron) {
            // NOTE(review): get(0) throws if the META activation has no final output
            // activations; the desired handling of that case is not visible here.
            Activation.SynapseActivation inhibMetaLink = metaAct.getFinalOutputActivations().get(0);
            Synapse.Key inhibSynKey = inhibMetaLink.synapse.key;
            MetaSynapse inhibSS = inhibMetaLink.synapse.meta;
            t.inhibNeuron.addSynapse(new Synapse.Builder()
                    .setNeuron(t.targetNeuron)
                    .setWeight(inhibSS.metaWeight)
                    .setBias(inhibSS.metaBias)
                    .setRelativeRid(inhibSynKey.relativeRid)
                    .setAbsoluteRid(inhibSynKey.absoluteRid)
                    .setRangeMatch(inhibSynKey.rangeMatch)
                    .setRangeOutput(inhibSynKey.rangeOutput));
        }
        doc.propagate();
    }

    /**
     * Creates a synapse from {@code in} to the target neuron of {@code t}, keyed like
     * {@code osk}, and, if {@code exists()} reports that it does not exist yet, applies the
     * meta weight/bias as a delta and adds it to {@code inputSynapses}.
     *
     * @param doc              the document being trained
     * @param t                the target whose neuron is the synapse output
     * @param inputSynapses    collector for synapses that were newly created
     * @param inputMetaSynapse meta information supplying weight, bias and the
     *                         {@code metaRelativeRid} flag
     * @param osk              key of the originating synapse
     * @param nRid             relative rid used when {@code osk} has none and
     *                         {@code metaRelativeRid} is set; may be {@code null}
     * @param in               the input neuron of the synapse
     */
    private static void createOrLookupSynapse(Document doc, Target t, TreeSet<Synapse> inputSynapses, MetaSynapse inputMetaSynapse, Synapse.Key osk, Integer nRid, Neuron in) {
        // Keep the original relative rid when present; otherwise fall back to nRid only if
        // the meta synapse asks for it.
        Integer relativeRid = osk.relativeRid;
        if (relativeRid == null && inputMetaSynapse.metaRelativeRid) {
            relativeRid = nRid;
        }
        Synapse.Key nsk = new Synapse.Key(osk.isRecurrent, relativeRid, osk.absoluteRid, osk.rangeMatch, osk.rangeOutput);
        Synapse ns = new Synapse(in, t.targetNeuron, nsk);
        if (!ns.exists()) {
            ns.updateDelta(doc, inputMetaSynapse.metaWeight, inputMetaSynapse.metaBias);
            inputSynapses.add(ns);
        }
    }

    /**
     * Renders a multi-line description of the pending changes on neuron {@code n}: one
     * header line with the old and new bias sum, followed by one line per synapse whose
     * {@code weightDelta} is non-zero.
     *
     * @param n        the neuron whose bias is reported
     * @param synapses the synapses to report
     * @return the formatted text, each line terminated by {@code '\n'}
     */
    public static String showDelta(INeuron n, Set<Synapse> synapses) {
        StringBuilder sb = new StringBuilder();
        sb.append("N: ").append(n.label)
                .append(" ob:").append(n.biasSum)
                .append(" nb:").append(n.biasSum + n.biasSumDelta)
                .append("\n");
        for (Synapse s : synapses) {
            if (s.weightDelta == 0.0) {
                continue;
            }
            // The frequency is only printed when the input neuron carries a statistic.
            Integer f = null;
            if (((INeuron)s.input.get()).statistic != null) {
                f = ((NeuronStatistic)((INeuron)s.input.get()).statistic).frequency;
            }
            sb.append("    S:").append(s.input.getLabel())
                    .append(" ow:").append(s.weight)
                    .append(" nw:").append(s.getNewWeight());
            if (f != null) {
                sb.append(" f:").append(f);
            }
            sb.append(" ").append(s.isConjunction(false, false) ? "CONJ" : "DISJ").append("\n");
        }
        return sb.toString();
    }

    /**
     * Computes {@code input rid - synapse relativeRid} for the first final input activation
     * of {@code mAct} where both values are non-null.
     *
     * @param mAct the activation whose final inputs are scanned
     * @return the offset, or {@code null} if no input provides both values
     */
    private static Integer computeRidOffset(Activation mAct) {
        for (Activation.SynapseActivation sa : mAct.getFinalInputActivations()) {
            if (sa.synapse.key.relativeRid == null || sa.input.key.rid == null) {
                continue;
            }
            return sa.input.key.rid - sa.synapse.key.relativeRid;
        }
        return null;
    }

    /**
     * Sets {@code metaBias} on the neuron behind {@code n} and initialises it as a META
     * neuron via {@code Neuron.init}.
     *
     * @param n        the neuron to initialise
     * @param bias     bias passed to {@code Neuron.init}
     * @param metaBias value stored in the neuron's {@code metaBias} field
     * @param inputs   synapse builders passed to {@code Neuron.init}
     * @return the result of {@code Neuron.init}
     */
    public static Neuron initMetaNeuron(Neuron n, double bias, double metaBias, Synapse.Builder ... inputs) {
        ((INeuron)n.get()).metaBias = metaBias;
        return Neuron.init(n, bias, INeuron.Type.META, inputs);
    }

    /**
     * Immutable record of one training target: the neuron to receive synapses, whether it
     * was created during this pass, and the inhibitory neuron it was discovered through.
     */
    private static class Target {
        final Neuron targetNeuron;
        final boolean isNewNeuron;
        final Neuron inhibNeuron;

        public Target(Neuron targetNeuron, boolean isNewNeuron, Neuron inhibNeuron) {
            this.targetNeuron = targetNeuron;
            this.isNewNeuron = isNewNeuron;
            this.inhibNeuron = inhibNeuron;
        }
    }
}

