/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.svm;

import org.encog.engine.util.ErrorCalculation;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.svm.KernelType;
import org.encog.neural.networks.svm.SVMNetwork;
import org.encog.neural.networks.training.BasicTraining;
import org.encog.neural.networks.training.svm.EncodeSVMProblem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SVMTrain
extends BasicTraining {
    private static final transient Logger LOGGER = LoggerFactory.getLogger(SVMTrain.class);
    public static final double DEFAULT_CONST_BEGIN = -5.0;
    public static final double DEFAULT_CONST_END = 15.0;
    public static final double DEFAULT_CONST_STEP = 2.0;
    public static final double DEFAULT_GAMMA_BEGIN = -10.0;
    public static final double DEFAULT_GAMMA_END = 10.0;
    public static final double DEFAULT_GAMMA_STEP = 1.0;
    private SVMNetwork network;
    private svm_problem[] problem;
    private int fold = 5;
    private double constBegin = -5.0;
    private double constStep = 15.0;
    private double constEnd = 2.0;
    private double gammaBegin = -10.0;
    private double gammaEnd = 10.0;
    private double gammaStep = 1.0;
    private double[] bestConst;
    private double[] bestGamma;
    private double[] bestError;
    private double[] currentConst;
    private double[] currentGamma;
    private boolean isSetup;
    private boolean trainingDone;

    public SVMTrain(BasicNetwork network, NeuralDataSet training) {
        this.network = (SVMNetwork)network;
        this.setTraining(training);
        this.isSetup = false;
        this.trainingDone = false;
        this.problem = new svm_problem[this.network.getOutputCount()];
        for (int i = 0; i < this.network.getOutputCount(); ++i) {
            this.problem[i] = EncodeSVMProblem.encode(training, i);
        }
    }

    public void train() {
        double gamma = 1.0 / (double)this.network.getInputCount();
        double c = 1.0;
        for (int i = 0; i < this.network.getOutputCount(); ++i) {
            this.train(i, gamma, c);
        }
    }

    public void train(int index, double gamma, double c) {
        this.network.getParams()[index].C = c;
        this.network.getParams()[index].gamma = gamma > 1.0E-7 ? 1.0 / (double)this.network.getInputCount() : gamma;
        this.network.getModels()[index] = svm.svm_train(this.problem[index], this.network.getParams()[index]);
    }

    public double crossValidate(int index, double gamma, double c) {
        double[] target = new double[this.problem[0].l];
        this.network.getParams()[index].C = c;
        this.network.getParams()[index].gamma = gamma;
        svm.svm_cross_validation(this.problem[index], this.network.getParams()[index], this.fold, target);
        return this.evaluate(this.network.getParams()[index], this.problem[index], target);
    }

    private double evaluate(svm_parameter param, svm_problem prob, double[] target) {
        int total_correct = 0;
        ErrorCalculation error = new ErrorCalculation();
        if (param.svm_type == 3 || param.svm_type == 4) {
            for (int i = 0; i < prob.l; ++i) {
                double ideal = prob.y[i];
                double actual = target[i];
                error.updateError(actual, ideal);
            }
            return error.calculate();
        }
        for (int i = 0; i < prob.l; ++i) {
            if (target[i] != prob.y[i]) continue;
            ++total_correct;
        }
        return 100.0 * (double)total_correct / (double)prob.l;
    }

    private void setup() {
        this.currentConst = new double[this.network.getOutputCount()];
        this.currentGamma = new double[this.network.getOutputCount()];
        this.bestConst = new double[this.network.getOutputCount()];
        this.bestGamma = new double[this.network.getOutputCount()];
        this.bestError = new double[this.network.getOutputCount()];
        for (int i = 0; i < this.network.getOutputCount(); ++i) {
            this.currentConst[i] = this.constBegin;
            this.currentGamma[i] = this.gammaBegin;
            this.bestError[i] = Double.POSITIVE_INFINITY;
        }
        this.isSetup = true;
    }

    @Override
    public void iteration() {
        if (!this.trainingDone) {
            if (!this.isSetup) {
                this.setup();
            }
            this.preIteration();
            if (this.network.getKernelType() == KernelType.RadialBasisFunction) {
                double totalError = 0.0;
                for (int i = 0; i < this.network.getOutputCount(); ++i) {
                    double e = this.crossValidate(i, this.currentGamma[i], this.currentConst[i]);
                    if (e < this.bestError[i]) {
                        this.bestConst[i] = this.currentConst[i];
                        this.bestGamma[i] = this.currentGamma[i];
                        this.bestError[i] = e;
                    }
                    int n = i;
                    this.currentConst[n] = this.currentConst[n] + this.constStep;
                    if (this.currentConst[i] > this.constEnd) {
                        this.currentConst[i] = this.constBegin;
                        int n2 = i;
                        this.currentGamma[n2] = this.currentGamma[n2] + this.gammaStep;
                        if (this.currentGamma[i] > this.gammaEnd) {
                            this.trainingDone = true;
                        }
                    }
                    totalError += this.bestError[i];
                }
                this.setError(totalError / (double)this.network.getOutputCount());
            } else {
                this.train();
            }
            this.postIteration();
        }
    }

    public svm_problem[] getProblem() {
        return this.problem;
    }

    public int getFold() {
        return this.fold;
    }

    public void setFold(int fold) {
        this.fold = fold;
    }

    public double getConstBegin() {
        return this.constBegin;
    }

    public void setConstBegin(double constBegin) {
        this.constBegin = constBegin;
    }

    public double getConstStep() {
        return this.constStep;
    }

    public void setConstStep(double constStep) {
        this.constStep = constStep;
    }

    public double getConstEnd() {
        return this.constEnd;
    }

    public void setConstEnd(double constEnd) {
        this.constEnd = constEnd;
    }

    public double getGammaBegin() {
        return this.gammaBegin;
    }

    public void setGammaBegin(double gammaBegin) {
        this.gammaBegin = gammaBegin;
    }

    public double getGammaEnd() {
        return this.gammaEnd;
    }

    public void setGammaEnd(double gammaEnd) {
        this.gammaEnd = gammaEnd;
    }

    public double getGammaStep() {
        return this.gammaStep;
    }

    public void setGammaStep(double gammaStep) {
        this.gammaStep = gammaStep;
    }

    @Override
    public void finishTraining() {
        for (int i = 0; i < this.network.getOutputCount(); ++i) {
            this.train(i, this.bestGamma[i], this.bestConst[i]);
        }
    }

    @Override
    public BasicNetwork getNetwork() {
        return this.network;
    }

    @Override
    public boolean isTrainingDone() {
        return this.trainingDone;
    }

    public void train(double gamma, double c) {
        for (int i = 0; i < this.network.getOutputCount(); ++i) {
            this.train(i, gamma, c);
        }
    }
}

