/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.cpn;

import org.encog.engine.util.BoundMath;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.cpn.FindCPN;

/**
 * Trains the instar layer of a counter-propagation network (CPN).
 * <p>
 * On the first call to {@link #iteration()} the instar weights are seeded from
 * the training data (see {@link #initWeights()}). Each iteration then presents
 * every training element. For each element it:
 * <ol>
 * <li>computes the instar output;</li>
 * <li>lets {@link FindCPN#winner} pick a single winning instar neuron;</li>
 * <li>moves only that neuron's incoming weight column toward the input by
 * {@code learningRate * (input - weight)}.</li>
 * </ol>
 * The error reported for an iteration is the largest Euclidean distance
 * between a training input and its winner's weight vector, measured before
 * that element's update.
 */
public class TrainInstar
extends BasicTraining
implements LearningRate {
    /** The network being trained. */
    private final BasicNetwork network;
    /** The training elements; only the input half of each pair is used. */
    private final NeuralDataSet training;
    /** Fraction of the input-minus-weight difference applied to the winner per element. */
    private double learningRate;
    /** True until the instar weights have been seeded from the training data. */
    private boolean mustInit = true;
    /** Gives access to the CPN's input layer, its instar synapse, and winner selection. */
    private final FindCPN parts;

    /**
     * Creates an instar trainer.
     *
     * @param network      the counter-propagation network to train
     * @param training     the training data; only the inputs are used
     * @param learningRate fraction of the input-minus-weight difference applied
     *                     to the winning neuron's weights for each element
     */
    public TrainInstar(BasicNetwork network, NeuralDataSet training, double learningRate) {
        this.network = network;
        this.training = training;
        this.learningRate = learningRate;
        this.parts = new FindCPN(network);
    }

    /**
     * @return the learning rate applied to the winning neuron's weights
     */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * @return the network being trained
     */
    @Override
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Seeds the instar weights from the training data.
     * <p>
     * The input vector of the i-th training element is copied into column i of
     * the instar weight matrix. Instar neuron i's weight vector therefore starts
     * out equal to the i-th training input. Columns beyond the number of
     * training elements are left untouched.
     *
     * @throws IllegalStateException if there are more training elements than
     *         instar neurons, because there is then no column to hold the extra
     *         elements
     */
    private void initWeights() {
        final int instarCount = this.parts.getInstarSynapse().getToNeuronCount();
        final int inputCount = this.parts.getInputLayer().getNeuronCount();
        int column = 0;
        for (NeuralDataPair pair : this.training) {
            if (column >= instarCount) {
                // Without this check the matrix write below fails with an obscure index error.
                throw new IllegalStateException(
                        "Cannot initialize the instar weights from the training data: there are more "
                        + "training elements than instar neurons (" + instarCount + ").");
            }
            final NeuralData input = pair.getInput();
            for (int row = 0; row < inputCount; ++row) {
                this.parts.getInstarSynapse().getMatrix().set(row, column, input.getData(row));
            }
            ++column;
        }
        // The weights were edited directly, so the network must re-flatten them.
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
        this.mustInit = false;
    }

    /**
     * Performs one pass over the training data, updating only the winning
     * instar neuron for each element.
     * <p>
     * The error is set to the largest winner distance seen during the pass.
     * If the training set is empty, the error is set to 0.
     */
    @Override
    public void iteration() {
        if (this.mustInit) {
            this.initWeights();
        }
        // One bound for both the distance and the update loops keeps them consistent.
        final int inputCount = this.parts.getInstarSynapse().getFromNeuronCount();
        double worstDistance = Double.NEGATIVE_INFINITY;
        int presented = 0;
        for (NeuralDataPair pair : this.training) {
            final NeuralData input = pair.getInput();
            final NeuralData out = this.parts.getInstarSynapse().compute(input);
            final int winner = this.parts.winner(out);

            // Euclidean distance from the input to the winner's weight column, before updating it.
            double sumOfSquares = 0.0;
            for (int i = 0; i < inputCount; ++i) {
                final double diff = input.getData(i)
                        - this.parts.getInstarSynapse().getMatrix().get(i, winner);
                sumOfSquares += diff * diff;
            }
            final double distance = BoundMath.sqrt(sumOfSquares);
            if (distance > worstDistance) {
                worstDistance = distance;
            }

            // Move only the winner's weights part of the way toward the input.
            for (int j = 0; j < inputCount; ++j) {
                final double delta = this.learningRate
                        * (input.getData(j) - this.parts.getInstarSynapse().getMatrix().get(j, winner));
                this.parts.getInstarSynapse().getMatrix().add(j, winner, delta);
            }
            ++presented;
        }
        // The weights were edited directly, so the network must re-flatten them.
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
        // With no training elements there is no distance to report; do not publish -Infinity.
        this.setError(presented == 0 ? 0.0 : worstDistance);
    }

    /**
     * @param rate the new learning rate applied to the winning neuron's weights
     */
    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }
}

