/*
 * Decompiled with CFR 0.152.
 */
package network.aika.training;

import java.util.TreeSet;
import network.aika.Document;
import network.aika.neuron.INeuron;
import network.aika.neuron.Neuron;
import network.aika.neuron.Synapse;
import network.aika.neuron.activation.Activation;
import network.aika.training.SynapseEvaluation;

/**
 * Supervised training pass over the activations of a single {@link Document}.
 *
 * <p>Usage as far as this class shows: callers fill {@link #targetActivations}
 * with activations carrying a {@code targetValue} and then call
 * {@link #train(Config)}. That call derives error signals for the targets,
 * optionally propagates them backwards along the input links, and finally
 * adjusts bias and synapse weights of every neuron whose activation received
 * an error signal.
 *
 * <p>The fields are public and mutable; this class performs no synchronization.
 */
public class SupervisedTraining {
    /** Document whose activations are trained. */
    public Document doc;

    /** Activations for which an output error signal is computed by {@link #train(Config)}. */
    public TreeSet<Activation> targetActivations = new TreeSet<>();

    /** Activations with a non-zero error signal; emptied at the end of {@link #train(Config)}. */
    public TreeSet<Activation> errorSignalActivations = new TreeSet<>();

    /** Work queue of activations that still need their back-propagated error signal computed. */
    public BackPropagationQueue queue = new BackPropagationQueue();

    /**
     * @param doc the document whose activations will be trained
     */
    public SupervisedTraining(Document doc) {
        this.doc = doc;
    }

    /**
     * Runs one training step.
     *
     * <ol>
     *   <li>computes the output error signal for every target activation,</li>
     *   <li>if {@code config.performBackpropagation} is set, drains the back-propagation queue,</li>
     *   <li>trains the neuron of every activation that ended up with an error signal,</li>
     *   <li>clears {@link #errorSignalActivations}.</li>
     * </ol>
     *
     * <p>{@link #targetActivations} is not cleared by this method.
     *
     * @param config learn rate, synapse evaluation strategy and back-propagation switch
     */
    public void train(Config config) {
        this.targetActivations.forEach(this::computeOutputErrorSignal);
        if (config.performBackpropagation) {
            this.queue.backpropagtion();
        }
        for (Activation act : this.errorSignalActivations) {
            this.train(act.getINeuron(), act, config.learnRate, config.synapseEvaluation);
        }
        this.errorSignalActivations.clear();
    }

    /**
     * Trains a single neuron from the error signal of one of its activations.
     *
     * <p>Does nothing when the absolute error signal is below {@code INeuron.TOLERANCE}.
     * Otherwise the bias is changed by {@code learnRate * errorSignal} and every
     * activation returned by {@code doc.getActivations(true)} is offered to the
     * synapse evaluation; each one that yields a non-null result has its synapse
     * to {@code n} created or updated.
     *
     * @param n         neuron to train
     * @param targetAct activation of {@code n} carrying the error signal
     * @param learnRate factor applied to the error signal
     * @param se        decides for each candidate input activation whether and how a synapse is trained
     */
    public void train(INeuron n, Activation targetAct, double learnRate, SynapseEvaluation se) {
        if (Math.abs(targetAct.errorSignal) < INeuron.TOLERANCE) {
            return;
        }
        // A fresh marker value per call lets trainSynapse skip activations it has already handled.
        long v = this.doc.visitedCounter++;
        double x = learnRate * targetAct.errorSignal;
        n.changeBias(x);
        // The cast is kept from the decompiled source because the element type of
        // getActivations(...) is not visible from this file.
        this.doc.getActivations(true).forEach(iAct -> {
            SynapseEvaluation.Result r = se.evaluate(null, (Activation) iAct, targetAct);
            if (r != null) {
                this.trainSynapse(n, (Activation) iAct, r, x, v);
            }
        });
    }

    /**
     * Adds the output error ({@code target - actual}) to the activation's error
     * signal and registers the activation for training.
     *
     * <ul>
     *   <li>Positive target: the actual value is the final state's value.</li>
     *   <li>Non-positive target with a positive upper bound: the upper bound is
     *       used as the actual value, so the resulting error is never positive.</li>
     *   <li>No target value: the error signal is left unchanged.</li>
     * </ul>
     *
     * @param act activation whose {@code errorSignal} is updated in place
     */
    public void computeOutputErrorSignal(Activation act) {
        if (act.targetValue != null) {
            if (act.targetValue > 0.0) {
                act.errorSignal += act.targetValue - act.getFinalState().value;
            } else if (act.upperBound > 0.0) {
                // Previously "upperBound - 1.0", which ignored the target and produced
                // no error at all for an upper bound of exactly 1.0.
                act.errorSignal += act.targetValue - act.upperBound;
            }
        }
        this.updateErrorSignal(act);
    }

    /**
     * Accumulates the error signal flowing back from all output links of
     * {@code act} and registers the activation for training.
     *
     * <p>Each output link contributes
     * {@code weight * outputErrorSignal * (1 - finalValue(act))}.
     *
     * @param act activation whose {@code errorSignal} is updated in place
     */
    public void computeBackpropagationErrorSignal(Activation act) {
        for (Activation.Link l : act.neuronOutputs.values()) {
            Synapse s = l.synapse;
            Activation oAct = l.output;
            act.errorSignal += s.weight * oAct.errorSignal * (1.0 - act.getFinalState().value);
        }
        this.updateErrorSignal(act);
    }

    /**
     * If {@code act} has a non-zero error signal, remembers it for the weight
     * update and enqueues all of its input activations for back-propagation.
     *
     * @param act activation to inspect
     */
    public void updateErrorSignal(Activation act) {
        if (act.errorSignal != 0.0) {
            this.errorSignalActivations.add(act);
            for (Activation.Link l : act.neuronInputs.values()) {
                this.queue.add(l.input);
            }
        }
    }

    /**
     * Creates or looks up the synapse from the neuron of {@code iAct} to
     * {@code n} and applies a weight delta of
     * {@code x * r.significance * finalValue(iAct)}.
     *
     * <p>Skipped when {@code iAct} was already visited with marker {@code v},
     * or when {@code iAct} belongs to {@code n} itself (no self-synapse).
     *
     * @param n    output neuron being trained
     * @param iAct candidate input activation
     * @param r    evaluation result describing the synapse to train
     * @param x    learn rate multiplied by the error signal
     * @param v    visit marker of the current training call
     */
    private void trainSynapse(INeuron n, Activation iAct, SynapseEvaluation.Result r, double x, long v) {
        if (iAct.visited == v) {
            return;
        }
        iAct.visited = v;
        INeuron inputNeuron = iAct.getINeuron();
        if (inputNeuron == n) {
            return;
        }
        double deltaW = x * r.significance * iAct.getFinalState().value;
        Synapse synapse = Synapse.createOrLookup(this.doc, null, r.synapseKey, r.relations, r.distanceFunction, (Neuron)inputNeuron.provider, (Neuron)n.provider);
        synapse.updateDelta(this.doc, deltaW, 0.0);
        r.deleteMode.checkIfDelete(synapse, true);
    }

    /**
     * Ordered work queue for back-propagation.
     *
     * <p>Activations are ordered by the {@code fired} value of their final state
     * in descending order, ties broken by ascending activation id. An activation
     * is held at most once, tracked through its {@code isQueued} flag.
     */
    public class BackPropagationQueue {
        /** Backing set; the comparator defines the processing order described on the class. */
        public final TreeSet<Activation> queue = new TreeSet<>((Activation act1, Activation act2) -> {
            Activation.State fs1 = act1.getFinalState();
            Activation.State fs2 = act2.getFinalState();
            // Arguments swapped on purpose: larger "fired" values sort first.
            int r = Integer.compare(fs2.fired, fs1.fired);
            if (r != 0) {
                return r;
            }
            return Integer.compare(act1.id, act2.id);
        });

        /** Source of the {@code queueId} stamped on each enqueued activation. */
        private long queueIdCounter = 0L;

        /**
         * Enqueues {@code act} unless it is already queued.
         *
         * @param act activation to schedule for back-propagation
         */
        public void add(Activation act) {
            if (!act.isQueued) {
                act.isQueued = true;
                act.queueId = this.queueIdCounter++;
                this.queue.add(act);
            }
        }

        /**
         * Drains the queue, computing the back-propagated error signal for each
         * activation. Processing an activation may enqueue further activations,
         * which are handled in the same run.
         *
         * <p>The misspelled method name is kept because it is part of the public interface.
         */
        public void backpropagtion() {
            while (!this.queue.isEmpty()) {
                Activation act = this.queue.pollFirst();
                act.isQueued = false;
                SupervisedTraining.this.computeBackpropagationErrorSignal(act);
            }
        }
    }

    /**
     * Mutable settings for {@link SupervisedTraining#train(Config)} with chainable setters.
     */
    public static class Config {
        /** Strategy deciding which input activations get a synapse trained. */
        public SynapseEvaluation synapseEvaluation;

        /** Factor applied to the error signal when changing bias and weights. */
        public double learnRate;

        /** Whether error signals are propagated backwards before the weight update. */
        public boolean performBackpropagation;

        /**
         * @param synapseEvaluation the synapse evaluation strategy to use
         * @return this config, for chaining
         */
        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        /**
         * @param learnRate the learn rate to use
         * @return this config, for chaining
         */
        public Config setLearnRate(double learnRate) {
            this.learnRate = learnRate;
            return this;
        }

        /**
         * @param performBackpropagation {@code true} to run back-propagation during training
         * @return this config, for chaining
         */
        public Config setPerformBackpropagation(boolean performBackpropagation) {
            this.performBackpropagation = performBackpropagation;
            return this;
        }
    }
}

