/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.importance;

import org.encog.EncogError;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLContext;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.importance.AbstractFeatureImportance;
import org.encog.ml.importance.FeatureRank;
import org.encog.util.EngineArray;

/**
 * Ranks input features by perturbation (permutation) importance.
 *
 * <p>For each input feature, the values of that feature are randomly permuted across the rows of
 * a dataset while every other input is left intact. The model is evaluated on the perturbed rows
 * and the resulting error is stored as the feature's total weight. Each feature's importance
 * percentage is its error divided by the largest error observed, so the feature whose
 * perturbation produced the highest error scores 1.0.
 *
 * <p>Not thread-safe: a single shuffle buffer and random generator are shared by all calls on an
 * instance.
 */
public class PerturbationFeatureImportanceCalc extends AbstractFeatureImportance {

    /** Source of randomness used to permute the perturbed column. */
    private GenerateRandom rnd = new MersenneTwisterGenerateRandom();

    /** Scratch buffer holding the values of the feature being perturbed, one entry per row. */
    private double[] shuffleColumn;

    /**
     * Not supported: this algorithm needs a dataset to measure error against.
     *
     * @throws EncogError always; call {@link #performRanking(MLDataSet)} instead.
     */
    @Override
    public void performRanking() {
        throw new EncogError("This algorithm requires a dataset to measure performance against, please call performRanking with a dataset.");
    }

    /**
     * Evaluates the model over {@code dataset} with column {@code perturbFeature} randomly
     * permuted across the rows, all other inputs unchanged.
     *
     * <p>Requires {@link #shuffleColumn} to have at least {@code dataset.size()} entries.
     *
     * @param dataset        rows to evaluate; ideals and significances feed the error calculation
     * @param perturbFeature index of the input column to permute
     * @return the error reported by {@link ErrorCalculation#calculate()} for the perturbed data,
     *         or {@link Double#NaN} if the model threw an {@link EncogError} during evaluation
     */
    private double calculateRegressionError(MLDataSet dataset, int perturbFeature) {
        final ErrorCalculation errorCalculation = new ErrorCalculation();
        final Object model = getModel();
        if (model instanceof MLContext) {
            ((MLContext) model).clearContext();
        }

        final int n = dataset.size();
        for (int i = 0; i < n; ++i) {
            shuffleColumn[i] = dataset.get(i).getInput().getData(perturbFeature);
        }

        final BasicMLData featureVector = new BasicMLData(dataset.getInputSize());
        try {
            for (int i = 0; i < n; ++i) {
                final MLDataPair pair = dataset.get(i);
                EngineArray.arrayCopy(pair.getInput().getData(), featureVector.getData());

                // Forward Fisher-Yates step: choose from the not-yet-placed tail [i, n) so the
                // column becomes a uniform permutation (assuming rnd.nextInt(k) is uniform on
                // [0, k)). Position i is final after this swap, so row i -- including the last
                // row -- always receives the permuted value.
                final int j = i + rnd.nextInt(n - i);
                final double tmp = shuffleColumn[i];
                shuffleColumn[i] = shuffleColumn[j];
                shuffleColumn[j] = tmp;
                featureVector.setData(perturbFeature, shuffleColumn[i]);

                final MLData actual = getModel().compute(featureVector);
                errorCalculation.updateError(actual.getData(), pair.getIdeal().getData(), pair.getSignificance());
            }
        } catch (EncogError e) {
            // Deliberate best-effort: a model that cannot evaluate the perturbed input yields NaN
            // for this feature instead of aborting the whole ranking.
            return Double.NaN;
        }
        return errorCalculation.calculate();
    }

    /**
     * Computes perturbation importance for each model input against {@code theDataset}.
     *
     * <p>For input {@code i} (0 .. model input count - 1), feature {@code i} receives its perturbed
     * error as total weight. Afterwards every feature's importance percent is set to its total
     * weight divided by the largest non-NaN error. If no positive error was observed, non-NaN
     * features get 0.0. Features whose evaluation failed keep NaN.
     *
     * @param theDataset dataset used to measure model error
     */
    @Override
    public void performRanking(MLDataSet theDataset) {
        shuffleColumn = new double[theDataset.size()];
        final int inputCount = getModel().getInputCount();
        double max = 0.0;
        for (int i = 0; i < inputCount; ++i) {
            final FeatureRank fr = getFeatures().get(i);
            final double e = calculateRegressionError(theDataset, i);
            fr.setTotalWeight(e);
            // A plain comparison skips NaN; Math.max would propagate it and turn every
            // percentage below into NaN because of a single failed feature.
            if (e > max) {
                max = e;
            }
        }
        for (FeatureRank fr : getFeatures()) {
            final double e = fr.getTotalWeight();
            if (max > 0.0) {
                fr.setImportancePercent(e / max);
            } else {
                // Every error was zero (or NaN): avoid 0/0, keeping NaN only for failed features.
                fr.setImportancePercent(Double.isNaN(e) ? Double.NaN : 0.0);
            }
        }
    }

    /**
     * @return the random generator used to permute columns
     */
    public GenerateRandom getRnd() {
        return this.rnd;
    }

    /**
     * Replaces the random generator used to permute columns (e.g. a seeded one for
     * reproducible rankings).
     *
     * @param rnd the generator to use; must not be null, or ranking will fail when it is used
     */
    public void setRnd(GenerateRandom rnd) {
        this.rnd = rnd;
    }
}

