/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.svd;

import org.encog.engine.network.rbf.RadialBasisFunction;
import org.encog.engine.util.ObjectPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.layers.RadialBasisFunctionLayer;
import org.encog.neural.networks.structure.FlatUpdateNeeded;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.svd.SVD;
import org.encog.util.simple.TrainingSetUtil;

/**
 * Trainer that fits the weights of a radial basis function network by calling
 * {@link SVD#svdfit} over the complete training set.
 *
 * <p>The network must contain a layer named {@code "OUTPUT"} with exactly one
 * neuron and a layer named {@code "RBF"} that is a
 * {@link RadialBasisFunctionLayer}. Each {@link #iteration()} passes the
 * training inputs, ideal outputs, the weight matrix of the first synapse and the
 * RBF layer's basis functions to {@code SVD.svdfit}. It records that call's
 * return value as the training error.
 */
public class SVDTraining
extends BasicTraining {
    /** The network being trained. */
    private final BasicNetwork network;
    /** The network's radial basis function layer, resolved once at construction. */
    private final RadialBasisFunctionLayer rbfLayer;

    /**
     * Creates an SVD trainer for the given network and training data.
     *
     * @param network  the network to train; must have an {@code "OUTPUT"} layer with
     *                 a single neuron and an {@code "RBF"} layer of type
     *                 {@link RadialBasisFunctionLayer}
     * @param training the training set used by every iteration
     * @throws TrainingError if the output layer is missing or does not have exactly
     *                       one neuron, or if there is no {@code "RBF"} layer of the
     *                       required type
     */
    public SVDTraining(BasicNetwork network, NeuralDataSet training) {
        Layer outputLayer = network.getLayer("OUTPUT");
        if (outputLayer == null) {
            throw new TrainingError("SVD requires an output layer.");
        }
        if (outputLayer.getNeuronCount() != 1) {
            throw new TrainingError("SVD requires an output layer with a single neuron.");
        }
        // Look the layer up once and check its type: a missing layer (null) and a
        // layer of the wrong class both fail instanceof, so each case reports a
        // TrainingError rather than one of them escaping as a ClassCastException.
        Layer candidate = network.getLayer("RBF");
        if (!(candidate instanceof RadialBasisFunctionLayer)) {
            throw new TrainingError("SVD is only tested to work on radial basis function networks.");
        }
        this.rbfLayer = (RadialBasisFunctionLayer) candidate;
        this.setTraining(training);
        this.network = network;
    }

    /**
     * Returns the network being trained.
     *
     * @return the network passed to the constructor
     */
    @Override
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Performs one SVD fit over the entire training set and stores its result as
     * the current training error.
     */
    @Override
    public void iteration() {
        int count = this.rbfLayer.getNeuronCount();
        // Fetch the layer's basis functions once instead of on every loop pass.
        RadialBasisFunction[] available = this.rbfLayer.getRadialBasisFunction();
        // Pass only the first getNeuronCount() functions, one per RBF neuron.
        RadialBasisFunction[] funcs = new RadialBasisFunction[count];
        for (int i = 0; i < count; ++i) {
            funcs[i] = available[i];
        }
        ObjectPair<double[][], double[][]> data = TrainingSetUtil.trainingToArray(this.getTraining());
        // NOTE(review): assumes synapse 0 is the RBF -> OUTPUT connection and that
        // getData() exposes the matrix's live backing array, so svdfit's writes land
        // in the network — confirm against Synapse/Matrix.
        double[][] weights = this.network.getStructure().getSynapses().get(0).getMatrix().getData();
        this.setError(SVD.svdfit(data.getA(), data.getB(), weights, funcs));
        // The weights were modified via the matrix; presumably this flags the
        // structure so its flat representation is rebuilt from them — confirm.
        this.network.getStructure().setFlatUpdate(FlatUpdateNeeded.Flatten);
    }
}

