/*
 * Decompiled with CFR 0.152.
 */
package org.aika.training;

import java.util.Collections;
import java.util.TreeSet;
import org.aika.corpus.Document;
import org.aika.lattice.OrNode;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.SynapseEvaluation;

/**
 * Long-term potentiation (LTP) and long-term depression (LTD) of synapse weights, driven by the
 * final activations of a processed {@link Document}.
 *
 * <p>Neither phase writes {@code Synapse.weight} directly. They accumulate pending changes in
 * {@code weightDelta} / {@code biasDelta}, may flag synapses via {@code toBeDeleted}, and report
 * the touched synapses through {@code Document.notifyWeightsModified}.
 */
public class LongTermLearning {

    /**
     * Runs one learning pass over every final activation of {@code doc}. For each activation it
     * applies LTP first, then LTD on the activation's input synapses, then LTD on its output
     * synapses.
     *
     * @param doc    the processed document whose final activations drive learning
     * @param config learning rates and the synapse evaluation strategy
     */
    public static void train(Document doc, Config config) {
        doc.getFinalActivations().forEach(act -> {
            LongTermLearning.longTermPotentiation(doc, config, act);
            LongTermLearning.longTermDepression(doc, config, act, false);
            LongTermLearning.longTermDepression(doc, config, act, true);
        });
    }

    /**
     * Increases the pending weight of synapses leading from other final activations of
     * {@code doc} into the neuron of {@code act}. Missing synapses are obtained through
     * {@code Synapse.createOrLookup}.
     *
     * <p>The step size is {@code ltpLearnRate * (1 - value(act)) * norm}. Here {@code norm} is the
     * weight-times-value sum over {@code act}'s non-negative final input synapses, divided by the
     * neuron's {@code posDirSum + posRecSum}; it is {@code 1.0} when that sum is not positive.
     * Each candidate input activation contributes {@code value(input) * step * significance}, and
     * only positive contributions are applied.
     *
     * @param doc    the processed document
     * @param config learning configuration; reads {@code ltpLearnRate}, {@code beta} and
     *               {@code synapseEvaluation}
     * @param act    the activation whose incoming synapses are potentiated
     */
    public static void longTermPotentiation(Document doc, Config config, Activation act) {
        INeuron n = (INeuron)((OrNode)act.key.node).neuron.get();

        double iaSum = 0.0;
        for (Activation.SynapseActivation sa : act.getFinalInputActivations()) {
            if (sa.synapse.isNegative()) continue;
            iaSum += sa.input.getFinalState().value * sa.synapse.weight;
        }
        double posSum = n.posDirSum + n.posRecSum;
        double norm = posSum > 0.0 ? iaSum / posSum : 1.0;
        double x = config.ltpLearnRate * (1.0 - act.getFinalState().value) * norm;

        // Activations of the same node as act are skipped (this includes act itself).
        doc.getFinalActivations()
                .filter(iAct -> iAct.key.node != act.key.node)
                .forEach(iAct -> {
                    SynapseEvaluation.Result r = config.synapseEvaluation.evaluate(null, (Activation)iAct, act);
                    double sDelta = iAct.getFinalState().value * x * r.significance;
                    if (sDelta > 0.0) {
                        Synapse synapse = Synapse.createOrLookup(
                                r.synapseKey,
                                ((OrNode)iAct.key.node).neuron,
                                ((OrNode)act.key.node).neuron);
                        // The weight step is rounded to float precision before being accumulated.
                        synapse.weightDelta += (double)((float)sDelta);
                        synapse.biasDelta -= config.beta * sDelta;
                        assert !Double.isNaN(synapse.weightDelta) && !Double.isNaN(synapse.biasDelta);
                    }
                });

        doc.notifyWeightsModified(n, n.inputSynapses.values());
    }

    /**
     * Lowers the pending weight of non-negative synapses of {@code act}'s neuron that did not
     * take part in {@code act}'s final activation.
     *
     * <p>With {@code dir == false} the neuron's input synapses are considered, excluding those
     * referenced by {@code act.getFinalInputActivations()}. With {@code dir == true} its output
     * synapses are considered, excluding those referenced by {@code act.getFinalOutputActivations()}.
     * Each remaining synapse has {@code ltdLearnRate * value(act) * significance} subtracted from
     * its {@code weightDelta}. It is flagged {@code toBeDeleted} when the evaluation result has
     * {@code deleteIfNull} set and the prospective weight ({@code weight + weightDelta}) is not
     * positive.
     *
     * @param doc    the processed document, notified of the modified synapses
     * @param config learning configuration; reads {@code ltdLearnRate} and {@code synapseEvaluation}
     * @param act    the activation whose unused synapses are depressed
     * @param dir    {@code false} to process input synapses, {@code true} for output synapses
     */
    public static void longTermDepression(Document doc, Config config, Activation act, boolean dir) {
        INeuron n = (INeuron)((OrNode)act.key.node).neuron.get();

        // TreeSet membership is decided by the comparator for the side being processed, not equals().
        TreeSet<Synapse> actSyns = new TreeSet<Synapse>(dir ? Synapse.OUTPUT_SYNAPSE_COMP : Synapse.INPUT_SYNAPSE_COMP);
        (dir ? act.getFinalOutputActivations() : act.getFinalInputActivations()).forEach(sa -> actSyns.add(sa.synapse));

        (dir ? n.outputSynapses : n.inputSynapses).values().stream()
                .filter(s -> !s.isNegative() && !actSyns.contains(s))
                .forEach(s -> {
                    SynapseEvaluation.Result r = config.synapseEvaluation.evaluate((Synapse)s, dir ? act : null, dir ? null : act);
                    s.weightDelta -= (double)((float)(config.ltdLearnRate * act.getFinalState().value * r.significance));
                    // Pending changes are additive: LTP adds to weightDelta and LTD subtracts from it.
                    // The prospective weight is therefore weight + weightDelta.
                    if (r.deleteIfNull && s.weight + s.weightDelta <= 0.0) {
                        s.toBeDeleted = true;
                    }
                    if (dir) {
                        // Output synapses are reported to their own output neuron, one at a time.
                        doc.notifyWeightsModified((INeuron)s.output.get(), Collections.singletonList(s));
                    }
                });

        if (!dir) {
            doc.notifyWeightsModified(n, n.inputSynapses.values());
        }
    }

    /**
     * Parameters for {@link LongTermLearning}. All setters return {@code this} for chaining.
     */
    public static class Config {
        /** Consulted per candidate synapse; supplies significance, synapse key and the deletion flag. */
        public SynapseEvaluation synapseEvaluation;
        /** Learning rate used by long-term potentiation. */
        public double ltpLearnRate;
        /** Learning rate used by long-term depression. */
        public double ltdLearnRate;
        /** Factor applied to each LTP weight step before it is subtracted from the synapse's bias delta. */
        public double beta;

        /** Sets {@link #synapseEvaluation}; returns {@code this}. */
        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        /** Sets {@link #ltpLearnRate}; returns {@code this}. */
        public Config setLTPLearnRate(double learnRate) {
            this.ltpLearnRate = learnRate;
            return this;
        }

        /** Sets {@link #ltdLearnRate}; returns {@code this}. */
        public Config setLTDLearnRate(double learnRate) {
            this.ltdLearnRate = learnRate;
            return this;
        }

        /** Sets {@link #beta}; returns {@code this}. */
        public Config setBeta(double beta) {
            this.beta = beta;
            return this;
        }
    }
}

