/*
 * Decompiled with CFR 0.152.
 */
package org.aika.training;

import java.util.TreeSet;
import org.aika.Document;
import org.aika.Utils;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.neuron.activation.Activation;
import org.aika.training.SynapseEvaluation;

/**
 * Static long-term learning rules for the activations of a {@link Document}.
 * <p>
 * Two rules are provided:
 * <ul>
 *   <li>long-term potentiation, which increases synapse weights;</li>
 *   <li>long-term depression, which decreases weights of synapses that did not carry an
 *       active activation.</li>
 * </ul>
 */
public class LongTermLearning {

    /**
     * Trains every active activation of the document.
     * <p>
     * For each activation accepted by {@link #isActive(Activation)}, this runs long-term
     * potentiation, then long-term depression for input synapses ({@code dir == false}), then
     * for output synapses ({@code dir == true}).
     *
     * @param doc    the document whose activations are trained
     * @param config the learning parameters
     */
    public static void train(Document doc, Config config) {
        doc.getActivations().filter(LongTermLearning::isActive).forEach(act -> {
            longTermPotentiation(doc, config, act);
            longTermDepression(doc, config, act, false);
            longTermDepression(doc, config, act, true);
        });
    }

    /**
     * Decides whether an activation takes part in learning.
     * <p>
     * If a target value is set, it must be strictly positive. Otherwise the activation must be
     * a final activation.
     *
     * @param act the activation to test
     * @return {@code true} if the activation takes part in learning
     */
    private static boolean isActive(Activation act) {
        return act.targetValue == null ? act.isFinalActivation() : act.targetValue > 0.0;
    }

    /**
     * Returns the activation's final net value divided by the sum of the neuron's
     * {@code biasSum + posDirSum + posRecSum}.
     * <p>
     * NOTE(review): there is no guard against a zero denominator. Confirm that this sum can
     * never be 0 for neurons reaching this path.
     *
     * @param act the output activation
     * @return the normalised net value
     */
    private static double hConj(Activation act) {
        INeuron n = act.getINeuron();
        return act.getFinalState().net / (n.biasSum + n.posDirSum + n.posRecSum);
    }

    /**
     * Strengthens the synapses feeding into {@code act}.
     * <p>
     * The base learning signal is
     * {@code ltpLearnRate * (1 - value) * Utils.nullSafeMax(value, targetValue)}, where
     * {@code value} is the final value of {@code act}.
     * <p>
     * If {@link Config#createNewSynapses} is set, every active activation of the document on a
     * different node is a candidate input; its synapse is looked up or created on demand.
     * Otherwise only the existing input synapse activations of {@code act} whose input is
     * active are used.
     *
     * @param doc    the document providing candidate input activations
     * @param config the learning parameters
     * @param act    the output activation whose input synapses are strengthened
     */
    public static void longTermPotentiation(Document doc, Config config, Activation act) {
        double iv = Utils.nullSafeMax(act.getFinalState().value, act.targetValue);
        double x = config.ltpLearnRate * (1.0 - act.getFinalState().value) * iv;
        if (config.createNewSynapses) {
            doc.getActivations()
                    .filter(LongTermLearning::isActive)
                    // Skip activations that sit on the same node as act itself.
                    .filter(iAct -> iAct.key.node != act.key.node)
                    .forEach(iAct -> synapseLTP(config, null, iAct, act, x));
        } else {
            act.neuronInputs.stream()
                    .filter(sa -> isActive(sa.input))
                    .forEach(sa -> synapseLTP(config, sa.synapse, sa.input, act, x));
        }
    }

    /**
     * Applies one potentiation step to the synapse from {@code iAct} to {@code act}.
     * <p>
     * Nothing happens if the configured {@link SynapseEvaluation} returns {@code null}, or if the
     * computed delta is not positive.
     *
     * @param config the learning parameters
     * @param s      the existing synapse, or {@code null} when candidate synapses may be created
     * @param iAct   the input activation
     * @param act    the output activation
     * @param x      the base learning signal computed by {@link #longTermPotentiation}
     */
    private static void synapseLTP(Config config, Synapse s, Activation iAct, Activation act, double x) {
        SynapseEvaluation.Result r = config.synapseEvaluation.evaluate(s, iAct, act);
        if (r == null) {
            return;
        }
        // On the createNewSynapses path s is null, and s.isConjunction(...) would throw an NPE.
        // Such candidates are weighted as non-conjunctive (h = 1).
        // NOTE(review): confirm this is the intended weighting for newly created synapses.
        double h = (s != null && s.isConjunction(false, false)) ? hConj(act) : 1.0;
        double sDelta = iAct.getFinalState().value * x * r.significance * h;
        if (sDelta > 0.0) {
            Synapse synapse = Synapse.createOrLookup(act.doc, r.synapseKey, iAct.getNeuron(), act.getNeuron());
            synapse.updateDelta(act.doc, sDelta, -config.beta * sDelta);
        }
    }

    /**
     * Weakens synapses of {@code act}'s neuron that did not carry an active activation.
     * <p>
     * The method does nothing unless {@code act} has a positive final value. Otherwise it
     * considers the neuron's input synapses ({@code dir == false}) or output synapses
     * ({@code dir == true}). A synapse is penalised only if all of the following hold:
     * <ul>
     *   <li>it is not negative;</li>
     *   <li>it had no active activation on the opposite side of {@code act};</li>
     *   <li>{@code isConjunction(false, false) != dir};</li>
     *   <li>the configured {@link SynapseEvaluation} returns a result.</li>
     * </ul>
     *
     * @param doc    the document being trained
     * @param config the learning parameters
     * @param act    the activation whose neuron's synapses are depressed
     * @param dir    {@code true} for output synapses, {@code false} for input synapses
     */
    public static void longTermDepression(Document doc, Config config, Activation act, boolean dir) {
        if (act.getFinalState().value <= 0.0) {
            return;
        }
        INeuron n = act.getINeuron();

        // Synapses that carried an active activation for act in this document are exempt.
        TreeSet<Synapse> actSyns = new TreeSet<Synapse>(dir ? Synapse.OUTPUT_SYNAPSE_COMP : Synapse.INPUT_SYNAPSE_COMP);
        (dir ? act.neuronOutputs : act.neuronInputs).forEach(sa -> {
            Activation rAct = dir ? sa.output : sa.input;
            if (isActive(rAct)) {
                actSyns.add(sa.synapse);
            }
        });

        // NOTE(review): if checkIfDelete removes s from the synapse map being streamed here, this
        // could throw ConcurrentModificationException. Confirm deletion is deferred.
        (dir ? n.outputSynapses : n.inputSynapses).values().stream()
                .filter(s -> !s.isNegative() && !actSyns.contains(s))
                .forEach(s -> {
                    if (s.isConjunction(false, false) == dir) {
                        return;
                    }
                    SynapseEvaluation.Result r = config.synapseEvaluation.evaluate(s, dir ? act : null, dir ? null : act);
                    if (r == null) {
                        return;
                    }
                    s.updateDelta(doc, -config.ltdLearnRate * act.getFinalState().value * r.significance, 0.0);
                    r.deleteMode.checkIfDelete(s, false);
                });
    }

    /**
     * Fluent configuration for {@link LongTermLearning}.
     */
    public static class Config {
        /** Decides which synapses are updated and with what significance. */
        public SynapseEvaluation synapseEvaluation;
        /** Learning rate used by long-term potentiation. */
        public double ltpLearnRate;
        /** Learning rate used by long-term depression. */
        public double ltdLearnRate;
        /** Scales the second argument passed to {@code Synapse.updateDelta} during potentiation. */
        public double beta;
        /** If {@code true}, potentiation may look up or create synapses from any active activation. */
        public boolean createNewSynapses;

        /**
         * Sets the synapse evaluation.
         *
         * @param synapseEvaluation the evaluation to use
         * @return this config
         */
        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        /**
         * Sets the long-term potentiation learning rate.
         *
         * @param learnRate the learning rate
         * @return this config
         */
        public Config setLTPLearnRate(double learnRate) {
            this.ltpLearnRate = learnRate;
            return this;
        }

        /**
         * Sets the long-term depression learning rate.
         *
         * @param learnRate the learning rate
         * @return this config
         */
        public Config setLTDLearnRate(double learnRate) {
            this.ltdLearnRate = learnRate;
            return this;
        }

        /**
         * Sets beta.
         *
         * @param beta the beta value
         * @return this config
         */
        public Config setBeta(double beta) {
            this.beta = beta;
            return this;
        }

        /**
         * Enables or disables creation of new synapses during potentiation.
         *
         * @param createNewSynapses whether new synapses may be created
         * @return this config
         */
        public Config setCreateNewSynapses(boolean createNewSynapses) {
            this.createNewSynapses = createNewSynapses;
            return this;
        }
    }
}

