/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.svm.training;

import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.svm.training.EncodeSVMProblem;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * Single-pass trainer for an {@link SVM}.
 *
 * <p>One call to {@link #iteration()} pushes the configured {@code C} and
 * {@code gamma} values into the SVM's parameters and then either trains a
 * model directly or, when the fold count is greater than one, runs a
 * cross-validation and records its score. After that call
 * {@link #isTrainingDone()} reports {@code true}.
 */
public class SVMTrain extends BasicTraining {

    // NOTE(review): the six DEFAULT_* constants are not referenced inside this
    // class; presumably they are default ranges for a parameter search done
    // elsewhere - confirm against callers.

    /** Default beginning value for the constant (C) range. */
    public static final double DEFAULT_CONST_BEGIN = -5.0;

    /** Default ending value for the constant (C) range. */
    public static final double DEFAULT_CONST_END = 15.0;

    /** Default step for the constant (C) range. */
    public static final double DEFAULT_CONST_STEP = 2.0;

    /** Default beginning value for the gamma range. */
    public static final double DEFAULT_GAMMA_BEGIN = -10.0;

    /** Default ending value for the gamma range. */
    public static final double DEFAULT_GAMMA_END = 10.0;

    /** Default step for the gamma range. */
    public static final double DEFAULT_GAMMA_STEP = 1.0;

    /** The SVM being trained. */
    private final SVM network;

    /** The training data encoded as an {@code svm_problem}. */
    private svm_problem problem;

    /** Cross-validation fold count; values of one or less mean plain training. */
    private int fold = 0;

    /** Set to {@code true} once {@link #iteration()} has completed. */
    private boolean trainingDone;

    /** Gamma value copied into the SVM parameters on each iteration. */
    private double gamma;

    /** C value copied into the SVM parameters on each iteration. */
    private double c;

    /**
     * Creates a trainer for the given SVM and data set.
     *
     * <p>The data set is encoded into an {@code svm_problem} right away,
     * gamma starts at {@code 1 / inputCount} of the SVM, and C starts at 1.
     *
     * @param method  the SVM to train
     * @param dataSet the training data
     */
    public SVMTrain(SVM method, MLDataSet dataSet) {
        super(TrainingImplementationType.OnePass);
        this.network = method;
        setTraining(dataSet);
        this.trainingDone = false;
        this.problem = EncodeSVMProblem.encode(dataSet, 0);
        final int inputCount = this.network.getInputCount();
        this.gamma = 1.0 / (double) inputCount;
        this.c = 1.0;
    }

    // ------------------------------------------------------------------
    // Training
    // ------------------------------------------------------------------

    /**
     * Performs the single training pass.
     *
     * <p>Copies C and gamma into the SVM's parameters, then trains directly
     * when the fold count is one or less, or cross-validates otherwise, and
     * finally marks training as done.
     */
    @Override
    public final void iteration() {
        this.network.getParams().C = this.c;
        this.network.getParams().gamma = this.gamma;

        if (this.fold <= 1) {
            trainModel();
        } else {
            crossValidate();
        }

        this.trainingDone = true;
    }

    /**
     * Trains a model on the encoded problem, installs it on the SVM and
     * records the SVM's error over the training set.
     */
    private void trainModel() {
        this.network.setModel(svm.svm_train(this.problem, this.network.getParams()));
        setError(this.network.calculateError(getTraining()));
    }

    /**
     * Runs a cross-validation with the configured fold count, clears the
     * SVM's model and records the score returned by
     * {@link #evaluate(svm_parameter, svm_problem, double[])}.
     */
    private void crossValidate() {
        final double[] predictions = new double[this.problem.l];
        svm.svm_cross_validation(this.problem, this.network.getParams(), this.fold, predictions);
        // Cross-validation produces no single model, so any previous one is dropped.
        this.network.setModel(null);
        setError(evaluate(this.network.getParams(), this.problem, predictions));
    }

    /**
     * Scores cross-validation output against the problem's expected values.
     *
     * <p>For SVM types 3 and 4 the result comes from an
     * {@link ErrorCalculation} fed with every (actual, ideal) pair. For every
     * other type the result is the percentage (0-100) of entries whose
     * target equals the expected value exactly.
     *
     * @param param  the SVM parameters; only {@code svm_type} is read
     * @param prob   the problem supplying the expected values and their count
     * @param target the values produced by cross-validation
     * @return the calculated error, or the percentage of exact matches
     */
    private double evaluate(svm_parameter param, svm_problem prob, double[] target) {
        final ErrorCalculation errorCalc = new ErrorCalculation();

        // NOTE(review): 3 and 4 look like libsvm's EPSILON_SVR and NU_SVR
        // type codes - confirm against svm_parameter.
        final boolean exactMatchScoring = param.svm_type != 3 && param.svm_type != 4;

        if (exactMatchScoring) {
            int matches = 0;
            for (int idx = 0; idx < prob.l; idx++) {
                if (target[idx] == prob.y[idx]) {
                    matches++;
                }
            }
            return 100.0 * (double) matches / (double) prob.l;
        }

        for (int idx = 0; idx < prob.l; idx++) {
            final double expected = prob.y[idx];
            final double predicted = target[idx];
            errorCalc.updateError(predicted, expected);
        }
        return errorCalc.calculate();
    }

    // ------------------------------------------------------------------
    // Training life cycle
    // ------------------------------------------------------------------

    /**
     * @return always {@code false}
     */
    @Override
    public final boolean canContinue() {
        return false;
    }

    /**
     * @return {@code true} once {@link #iteration()} has run
     */
    @Override
    public final boolean isTrainingDone() {
        return this.trainingDone;
    }

    /**
     * @return always {@code null}
     */
    @Override
    public final TrainingContinuation pause() {
        return null;
    }

    /**
     * Does nothing.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
        // Intentionally empty.
    }

    // ------------------------------------------------------------------
    // Accessors
    // ------------------------------------------------------------------

    /**
     * @return the SVM being trained
     */
    @Override
    public final MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return the encoded training problem
     */
    public final svm_problem getProblem() {
        return this.problem;
    }

    /**
     * @return the C value applied on the next iteration
     */
    public final double getC() {
        return this.c;
    }

    /**
     * @param theC the C value to apply on the next iteration
     */
    public final void setC(double theC) {
        this.c = theC;
    }

    /**
     * @return the gamma value applied on the next iteration
     */
    public final double getGamma() {
        return this.gamma;
    }

    /**
     * @param theGamma the gamma value to apply on the next iteration
     */
    public final void setGamma(double theGamma) {
        this.gamma = theGamma;
    }

    /**
     * @return the cross-validation fold count
     */
    public final int getFold() {
        return this.fold;
    }

    /**
     * @param theFold the cross-validation fold count; cross-validation is
     *                used only when this is greater than one
     */
    public final void setFold(int theFold) {
        this.fold = theFold;
    }
}

