/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.hebbian;

import java.util.ArrayList;
import java.util.List;
import org.encog.mathutil.matrices.Matrix;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.TrainingError;

/**
 * Trains the synapses that feed the network's {@code "OUTPUT"} layer with a
 * Hebbian learning rule.
 * <p>
 * The variant used is chosen by the training data and the {@code oja} flag.
 * Here {@code x} is the input value feeding a weight, {@code y} is the network
 * output for the weight's target neuron, and {@code w} is the current weight:
 * <ul>
 * <li>unsupervised Hebbian: {@code dw = learningRate * x * y}</li>
 * <li>supervised Hebbian (training set has ideal values):
 * {@code dw = ideal * x * y}</li>
 * <li>Oja's rule (unsupervised only): {@code dw = learningRate * y * (x - y * w)}</li>
 * </ul>
 */
public class HebbianTraining
extends BasicTraining
implements LearningRate {
    /** Step size used by the unsupervised Hebbian and Oja updates. */
    private double learningRate;
    /** The network being trained. */
    private final BasicNetwork network;
    /** Training data iterated once per call to {@link #iteration()}. */
    private final NeuralDataSet training;
    /** True when the training set supplies ideal (target) values. */
    private final boolean supervised;
    /** True to use Oja's rule instead of the plain Hebbian rule. */
    private final boolean oja;
    /** The layer registered under the name "OUTPUT". */
    private final Layer outputLayer;
    /** Synapses that feed the output layer; only these are trained. */
    private final List<Synapse> outputSynapse;

    /**
     * Creates a Hebbian trainer.
     *
     * @param network      the network to train; it must have a layer named "OUTPUT"
     * @param training     the training data; ideal data makes training supervised
     * @param oja          true to use Oja's rule (only allowed for unsupervised data)
     * @param learningRate the step size used by the unsupervised and Oja updates
     * @throws TrainingError if the network has no "OUTPUT" layer, if Oja
     *                       training is requested with supervised data, or if
     *                       the output layer has no inbound synapses
     */
    public HebbianTraining(BasicNetwork network, NeuralDataSet training, boolean oja, double learningRate) {
        this.network = network;
        this.training = training;
        this.learningRate = learningRate;
        this.supervised = training.getIdealSize() > 0;
        this.oja = oja;
        this.outputLayer = this.network.getLayer("OUTPUT");
        if (this.outputLayer == null) {
            throw new TrainingError("Can't use Hebbian training without an output layer.");
        }
        if (this.oja && this.supervised) {
            throw new TrainingError("Can't use OJA Hebbian training with supervised data.");
        }
        this.outputSynapse = this.network.getStructure().getPreviousSynapses(this.outputLayer);
        if (this.outputSynapse.size() == 0) {
            throw new TrainingError("Can't use Hebbian learning, the output layer has no inbound synapses.");
        }
    }

    /** @return the current learning rate */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return the network being trained */
    @Override
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /** @return the training data set */
    @Override
    public NeuralDataSet getTraining() {
        return this.training;
    }

    /** @return true if Oja's rule is used */
    public boolean isOja() {
        return this.oja;
    }

    /** @return true if the training data supplies ideal values */
    public boolean isSupervised() {
        return this.supervised;
    }

    /**
     * Performs one pass over the training data, updating every synapse that
     * feeds the output layer once per training pair.
     */
    @Override
    public void iteration() {
        this.preIteration();
        for (NeuralDataPair pair : this.training) {
            for (Synapse synapse : this.outputSynapse) {
                this.trainSynapse(synapse, pair);
            }
        }
        this.postIteration();
    }

    /**
     * Sets the learning rate.
     *
     * @param rate the new learning rate
     */
    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /**
     * Applies one Hebbian (or Oja) weight update to a synapse for one pair.
     *
     * @param synapse the synapse whose weight matrix is updated in place
     * @param pair    the training pair supplying the input (and, if
     *                supervised, the ideal values)
     */
    private void trainSynapse(Synapse synapse, NeuralDataPair pair) {
        // The network output is recomputed per synapse, so a second synapse
        // sees the weights already updated for the first one.
        NeuralData outputData = this.network.compute(pair.getInput());
        // NOTE(review): x is taken from the network input, indexed by the
        // synapse's from-neuron. This is only correct when the synapse's source
        // is the input layer. Confirm for networks with hidden layers.
        double[] input = pair.getInput().getData();
        double[] output = outputData.getData();
        Matrix matrix = synapse.getMatrix();
        for (int toNeuron = 0; toNeuron < synapse.getToNeuronCount(); ++toNeuron) {
            double y = output[toNeuron];
            // NOTE(review): in supervised mode the ideal value replaces the
            // learning rate as the scale factor, so learningRate is not applied.
            // This behaviour is preserved from the original.
            double z = this.supervised ? pair.getIdeal().getData(toNeuron) : this.learningRate;
            for (int fromNeuron = 0; fromNeuron < synapse.getFromNeuronCount(); ++fromNeuron) {
                double x = input[fromNeuron];
                double deltaWeight = this.oja
                    ? ojaDelta(x, y, matrix.get(fromNeuron, toNeuron), this.learningRate)
                    : hebbianDelta(x, y, z);
                matrix.add(fromNeuron, toNeuron, deltaWeight);
            }
        }
    }

    /**
     * Computes the plain Hebbian update {@code scale * x * y}.
     *
     * @param x     the value on the source side of the weight
     * @param y     the output of the target neuron
     * @param scale the learning rate (unsupervised) or ideal value (supervised)
     * @return the weight change
     */
    private static double hebbianDelta(double x, double y, double scale) {
        return x * y * scale;
    }

    /**
     * Computes Oja's rule update {@code rate * y * (x - y * w)}. The
     * {@code -y * y * w} term is what distinguishes it from plain Hebbian growth.
     *
     * @param x      the value on the source side of the weight
     * @param y      the output of the target neuron
     * @param weight the current weight
     * @param rate   the learning rate
     * @return the weight change
     */
    private static double ojaDelta(double x, double y, double weight, double rate) {
        return rate * y * (x - y * weight);
    }
}

