/*
 * Decompiled with CFR 0.152.
 */
package org.aika.training;

import java.util.Collections;
import java.util.TreeSet;
import org.aika.corpus.Document;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.SynapseEvaluation;

public class LongTermLearning {

    /**
     * Applies long-term potentiation and long-term depression (on both the input side and the
     * output side) to every final activation of the document, then commits the document.
     *
     * @param doc    the document whose final activations are trained
     * @param config learning parameters
     */
    public static void train(Document doc, Config config) {
        doc.getFinalActivations().forEach(act -> {
            LongTermLearning.longTermPotentiation(doc, config, act);
            LongTermLearning.longTermDepression(doc, config, act, false);
            LongTermLearning.longTermDepression(doc, config, act, true);
        });
        doc.commit();
    }

    /**
     * Strengthens the input synapses of the neuron behind {@code act}.
     *
     * <p>The learning rate is scaled by {@code (1 - finalValue)}. It is additionally multiplied by
     * {@code max(1, finalValue / maxActValue)}, where {@code maxActValue} is the neuron's activation
     * function applied to {@code biasSum + posDirSum + posRecSum}.
     *
     * <p>If {@link Config#createNewSynapses} is set, every final activation of a different node is
     * evaluated as a potential input. Otherwise only the existing input activations of {@code act}
     * are used.
     *
     * @param doc    the document containing the activations
     * @param config learning parameters
     * @param act    the activation whose neuron's input synapses are trained
     */
    public static void longTermPotentiation(Document doc, Config config, Activation act) {
        INeuron n = act.getINeuron();
        double actValue = act.getFinalState().value;
        double maxActValue = n.activationFunction.f(n.biasSum + n.posDirSum + n.posRecSum);
        double m = maxActValue > 0.0 ? Math.max(1.0, actValue / maxActValue) : 1.0;
        double x = config.ltpLearnRate * (1.0 - actValue) * m;
        if (config.createNewSynapses) {
            // Candidate inputs: all final activations except those of the same node as 'act'.
            doc.getFinalActivations().filter(iAct -> iAct.key.node != act.key.node).forEach(iAct -> {
                SynapseEvaluation.Result r = config.synapseEvaluation.evaluate(null, iAct, act);
                LongTermLearning.synapseLTP(config, iAct, act, x, r);
            });
        } else {
            act.getFinalInputActivations().forEach(sa -> {
                SynapseEvaluation.Result r = config.synapseEvaluation.evaluate(sa.synapse, sa.input, act);
                LongTermLearning.synapseLTP(config, sa.input, act, x, r);
            });
        }
        doc.notifyWeightsModified(n, n.inputSynapses.values());
    }

    /**
     * Applies a single potentiation step to the synapse from {@code iAct} to {@code act}.
     *
     * <p>Nothing happens when the evaluation result is {@code null} or the computed delta
     * ({@code inputValue * x * significance}) is not positive. Otherwise the synapse is created or
     * looked up, its weight delta is increased, and its bias is changed by {@code -beta * delta}.
     */
    private static void synapseLTP(Config config, Activation iAct, Activation act, double x, SynapseEvaluation.Result r) {
        if (r == null) {
            return;
        }
        double sDelta = iAct.getFinalState().value * x * r.significance;
        if (sDelta > 0.0) {
            Synapse synapse = Synapse.createOrLookup(r.synapseKey, iAct.getNeuron(), act.getNeuron());
            // The delta is rounded to float precision before being accumulated.
            synapse.weightDelta += (float) sDelta;
            synapse.changeBias(-config.beta * sDelta);
        }
    }

    /**
     * Weakens synapses of the neuron behind {@code act} that did not take part in this activation.
     *
     * <p>Only runs when {@code act}'s final value is positive. With {@code dir == false} the
     * neuron's input synapses are considered; with {@code dir == true}, its output synapses.
     * Negative synapses and synapses that carried an activation for {@code act} are skipped. So are
     * synapses whose {@code isConjunction(false)} equals {@code dir}.
     *
     * @param doc    the document to notify about weight changes
     * @param config learning parameters
     * @param act    the activation whose neuron's synapses are depressed
     * @param dir    {@code false} for input synapses, {@code true} for output synapses
     */
    public static void longTermDepression(Document doc, Config config, Activation act, boolean dir) {
        double actValue = act.getFinalState().value;
        if (actValue <= 0.0) {
            return;
        }
        INeuron n = act.getINeuron();
        // Collect the synapses that were actually used by this activation on the chosen side.
        TreeSet<Synapse> actSyns = new TreeSet<Synapse>(dir ? Synapse.OUTPUT_SYNAPSE_COMP : Synapse.INPUT_SYNAPSE_COMP);
        (dir ? act.getFinalOutputActivations() : act.getFinalInputActivations()).forEach(sa -> actSyns.add(sa.synapse));
        (dir ? n.outputSynapses : n.inputSynapses).values().stream().filter(s -> !s.isNegative() && !actSyns.contains(s)).forEach(s -> {
            if (s.isConjunction(false) == dir) {
                return;
            }
            SynapseEvaluation.Result r = config.synapseEvaluation.evaluate(s, dir ? act : null, dir ? null : act);
            if (r == null) {
                return;
            }
            // The delta is rounded to float precision before being subtracted.
            s.weightDelta -= (float) (config.ltdLearnRate * actValue * r.significance);
            r.deleteMode.checkIfDelete(s);
            if (dir) {
                // Output synapses belong to other neurons; notify each one individually.
                doc.notifyWeightsModified((INeuron) s.output.get(), Collections.singletonList(s));
            }
        });
        if (!dir) {
            doc.notifyWeightsModified(n, n.inputSynapses.values());
        }
    }

    /** Fluent configuration of the long-term learning parameters. */
    public static class Config {
        /** Evaluates candidate synapses; a {@code null} result means "skip". */
        public SynapseEvaluation synapseEvaluation;
        /** Learning rate for long-term potentiation. */
        public double ltpLearnRate;
        /** Learning rate for long-term depression. */
        public double ltdLearnRate;
        /** Factor applied to the potentiation delta when adjusting a synapse's bias. */
        public double beta;
        /** Whether potentiation may consider all final activations, not only existing inputs. */
        public boolean createNewSynapses;

        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        public Config setLTPLearnRate(double learnRate) {
            this.ltpLearnRate = learnRate;
            return this;
        }

        public Config setLTDLearnRate(double learnRate) {
            this.ltdLearnRate = learnRate;
            return this;
        }

        public Config setBeta(double beta) {
            this.beta = beta;
            return this;
        }

        public Config setCreateNewSynapses(boolean createNewSynapses) {
            this.createNewSynapses = createNewSynapses;
            return this;
        }
    }
}

