/*
 * Decompiled with CFR 0.152.
 */
package org.aika.training;

import java.util.TreeSet;
import org.aika.Neuron;
import org.aika.corpus.Document;
import org.aika.lattice.OrNode;
import org.aika.neuron.Activation;
import org.aika.neuron.INeuron;
import org.aika.neuron.Synapse;
import org.aika.training.SynapseEvaluation;

/**
 * Supervised training pass over the activations of a single {@link Document}.
 *
 * <p>As visible in this class, a pass consists of:
 * <ol>
 *   <li>computing an error signal for every activation in {@link #targetActivations}
 *       that carries a {@code targetValue},</li>
 *   <li>optionally propagating error signals to input activations via {@link #queue},</li>
 *   <li>adjusting the bias delta and the input-synapse weight deltas of every neuron whose
 *       activation ended up in {@link #errorSignalActivations}.</li>
 * </ol>
 *
 * <p>The public fields are part of the class's interface and are left mutable on purpose.
 */
public class SupervisedTraining {
    /** Document whose activations are read and whose counters/weights are updated. */
    public Document doc;

    /** Activations for which {@link #train(Config)} computes an output error signal. */
    public TreeSet<Activation> targetActivations = new TreeSet<>();

    /**
     * Activations with a non-zero error signal, collected by {@link #updateErrorSignal(Activation)}
     * and cleared at the end of {@link #train(Config)}.
     */
    public TreeSet<Activation> errorSignalActivations = new TreeSet<>();

    /** Work list of activations still waiting for their backpropagated error signal. */
    public BackPropagationQueue queue = new BackPropagationQueue();

    /**
     * Creates a training context bound to the given document.
     *
     * @param doc the document to train on
     */
    public SupervisedTraining(Document doc) {
        this.doc = doc;
    }

    /**
     * Runs one complete training pass using the given configuration.
     *
     * <p>Note that only {@link #errorSignalActivations} is cleared afterwards;
     * {@link #targetActivations} is left untouched.
     *
     * @param config learn rate, synapse evaluation callback and backpropagation switch
     */
    public void train(Config config) {
        targetActivations.forEach(this::computeOutputErrorSignal);
        if (config.performBackpropagation) {
            // Deliberately calls the legacy (misspelled) name: it is the implementation
            // method, so existing subclasses overriding it keep being honoured.
            queue.backpropagtion();
        }
        for (Activation act : errorSignalActivations) {
            train(neuronOf(act), act, config.learnRate, config.synapseEvaluation);
        }
        errorSignalActivations.clear();
    }

    /**
     * Applies the error signal of a single activation to its neuron.
     *
     * <p>Does nothing when the absolute error signal is below {@code INeuron.TOLERANCE}.
     * Otherwise {@code learnRate * errorSignal} is added to the neuron's bias delta, every
     * final activation of the document is offered to {@code se}, and each one for which
     * {@code se} returns a non-null result is handed on to the synapse update. Finally the
     * document is notified about the modified weights.
     *
     * @param n         the neuron to train
     * @param targetAct the activation of {@code n} carrying the error signal
     * @param learnRate factor applied to the error signal
     * @param se        callback deciding whether (and how strongly) an input activation is
     *                  connected to {@code targetAct}; a {@code null} result skips the input
     */
    public void train(INeuron n, Activation targetAct, double learnRate, SynapseEvaluation se) {
        if (Math.abs(targetAct.errorSignal) < INeuron.TOLERANCE) {
            return;
        }
        // Fresh marker value so that trainSynapse handles each input activation at most once
        // during this call.
        long v = doc.visitedCounter++;
        double x = learnRate * targetAct.errorSignal;
        n.biasDelta += x;
        doc.getFinalActivations().forEach(iAct -> {
            SynapseEvaluation.Result ser = se.evaluate(null, (Activation) iAct, targetAct);
            if (ser != null) {
                trainSynapse(n, (Activation) iAct, ser, x, v);
            }
        });
        doc.notifyWeightsModified(n, ((Neuron) n.provider).inMemoryInputSynapses.values());
    }

    /**
     * Adds {@code targetValue - finalState.value} to the activation's error signal when a
     * target value is set, then registers the activation via
     * {@link #updateErrorSignal(Activation)}.
     *
     * @param act the output activation
     */
    public void computeOutputErrorSignal(Activation act) {
        if (act.targetValue != null) {
            act.errorSignal += act.targetValue - act.getFinalState().value;
        }
        updateErrorSignal(act);
    }

    /**
     * Accumulates the error signal flowing back into {@code act} from its outputs.
     *
     * <p>For every output link the term
     * {@code weight * outputErrorSignal * (1 - act.finalState.value)} is added, after which the
     * activation is registered via {@link #updateErrorSignal(Activation)}.
     *
     * @param act the activation whose outputs are inspected
     */
    public void computeBackpropagationErrorSignal(Activation act) {
        for (Activation.SynapseActivation sa : act.neuronOutputs) {
            Synapse s = sa.synapse;
            Activation oAct = sa.output;
            act.errorSignal += s.weight * oAct.errorSignal * (1.0 - act.getFinalState().value);
        }
        updateErrorSignal(act);
    }

    /**
     * Registers an activation with a non-zero error signal for training and enqueues all of its
     * input activations for backpropagation. Activations whose error signal is exactly zero are
     * ignored.
     *
     * @param act the activation to register
     */
    public void updateErrorSignal(Activation act) {
        // NOTE(review): exact comparison with 0.0 here, whereas train(...) skips anything below
        // INeuron.TOLERANCE - confirm whether the two thresholds are meant to differ.
        if (act.errorSignal != 0.0) {
            errorSignalActivations.add(act);
            for (Activation.SynapseActivation sa : act.neuronInputs) {
                queue.add(sa.input);
            }
        }
    }

    /**
     * Sets the weight delta of the synapse linking {@code iAct}'s neuron to {@code n}.
     *
     * @param n    the neuron being trained (synapse output side)
     * @param iAct candidate input activation
     * @param ser  evaluation result supplying the synapse key and significance
     * @param x    learn rate multiplied by the target activation's error signal
     * @param v    marker of the current training call, used to skip already handled activations
     */
    private void trainSynapse(INeuron n, Activation iAct, SynapseEvaluation.Result ser, double x, long v) {
        if (iAct.visited == v) {
            return;
        }
        iAct.visited = v;
        INeuron inputNeuron = neuronOf(iAct);
        if (inputNeuron == n) {
            // No synapse from a neuron onto itself.
            return;
        }
        double deltaW = x * ser.significance * iAct.getFinalState().value;
        Synapse synapse = Synapse.createOrLookup(ser.synapseKey, (Neuron) inputNeuron.provider, (Neuron) n.provider);
        // NOTE(review): this overwrites any pending weightDelta, while biasDelta is accumulated
        // with "+=". If several activations can resolve to the same synapse before the deltas
        // are applied, earlier contributions are lost - confirm whether "+=" is intended.
        synapse.weightDelta = (float) deltaW;
    }

    /** Resolves the neuron behind an activation for the current document. */
    private INeuron neuronOf(Activation act) {
        return (INeuron) ((OrNode) act.key.node).neuron.get(doc);
    }

    /**
     * Ordered work list used while propagating error signals backwards.
     *
     * <p>Activations are ordered by the {@code fired} value of their final state, highest
     * first; ties are broken by the activation key.
     */
    public class BackPropagationQueue {
        public final TreeSet<Activation> queue = new TreeSet<>((act1, act2) -> {
            Activation.State fs1 = act1.getFinalState();
            Activation.State fs2 = act2.getFinalState();
            // Arguments swapped on purpose: descending order of "fired".
            int r = Integer.compare(fs2.fired, fs1.fired);
            if (r != 0) {
                return r;
            }
            return act1.key.compareTo(act2.key);
        });

        /** Source of the ids stamped onto activations in {@link #add(Activation)}. */
        private long queueIdCounter = 0L;

        /**
         * Enqueues an activation unless it is already flagged as queued.
         *
         * <p>The activation is marked as queued and stamped with a running id; the ordering of
         * the queue itself does not consult that id.
         *
         * @param act the activation to enqueue
         */
        public void add(Activation act) {
            if (!act.isQueued) {
                act.isQueued = true;
                act.queueId = queueIdCounter++;
                queue.add(act);
            }
        }

        /**
         * Drains the queue, computing the backpropagated error signal for each polled
         * activation. Processing an activation may enqueue further activations, which are
         * handled in the same loop.
         *
         * <p>The name is a historical misspelling that is kept because it is public API;
         * {@link #backpropagation()} is the correctly spelled equivalent.
         */
        public void backpropagtion() {
            while (!queue.isEmpty()) {
                Activation act = queue.pollFirst();
                act.isQueued = false;
                SupervisedTraining.this.computeBackpropagationErrorSignal(act);
            }
        }

        /** Correctly spelled alias of {@link #backpropagtion()}. */
        public void backpropagation() {
            backpropagtion();
        }
    }

    /**
     * Settings for {@link SupervisedTraining#train(Config)}. All setters return {@code this}
     * so that calls can be chained.
     */
    public static class Config {
        /** Callback used to evaluate candidate input activations. */
        public SynapseEvaluation synapseEvaluation;

        /** Factor applied to error signals when computing bias and weight deltas. */
        public double learnRate;

        /** Whether error signals are propagated back to input activations before training. */
        public boolean performBackpropagation;

        /**
         * @param synapseEvaluation the evaluation callback to use
         * @return this configuration
         */
        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        /**
         * @param learnRate the learn rate to use
         * @return this configuration
         */
        public Config setLearnRate(double learnRate) {
            this.learnRate = learnRate;
            return this;
        }

        /**
         * @param performBackpropagation {@code true} to run the backpropagation step
         * @return this configuration
         */
        public Config setPerformBackpropagation(boolean performBackpropagation) {
            this.performBackpropagation = performBackpropagation;
            return this;
        }
    }
}

