/*
 * Decompiled with CFR 0.152.
 */
package edu.columbia.tjw.item.spark;

import java.util.Arrays;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.ProbabilisticClassificationModel;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;

/**
 * Fits a Spark ML probabilistic classifier and reports its fitting time together with entropy
 * statistics (mean cross-entropy against the observed class, and mean entropy of the predicted
 * class distribution) on a fitting set and a testing set.
 *
 * <p>Although {@link #evaluate} accepts any probabilistic classifier, the fitted model must be a
 * {@link MultilayerPerceptronClassificationModel}, because {@link EvaluationResult} stores it with
 * that type.
 */
public class ClassificationModelEvaluator {

    /**
     * Column holding the index of the observed class, used to pick the predicted probability of the
     * true outcome when computing cross-entropy.
     * NOTE(review): hard-coded; presumably an integer 0-based class index matching the classifier's
     * label column — confirm against callers.
     */
    private static final String OBSERVED_CLASS_COLUMN = "next_status";

    /**
     * Fits {@code classifier_} on {@code fitting}, then computes entropy statistics of the fitted
     * model on both {@code fitting} and {@code testing}.
     *
     * @param classifier_ classifier to fit; it must produce a {@link MultilayerPerceptronClassificationModel}
     * @param label_ free-form label copied into the result
     * @param fitting data used to fit the model (also scored for in-sample entropy)
     * @param testing held-out data scored for out-of-sample entropy
     * @param prngSeed_ seed recorded in the result; this method does not apply it to the classifier
     * @param layers_ layer sizes recorded (via {@link Arrays#toString(int[])}) in the result; this
     *     method does not apply them to the classifier
     * @return the fitted model, its fitting time in milliseconds, and the entropy results
     * @throws ClassCastException if the fitted model is not a {@link MultilayerPerceptronClassificationModel}
     */
    public static <W extends ProbabilisticClassificationModel<Vector, W>, M extends ProbabilisticClassifier<Vector, M, W>> EvaluationResult evaluate(M classifier_, String label_, Dataset<Row> fitting, Dataset<Row> testing, long prngSeed_, int[] layers_) {
        // nanoTime is monotonic; currentTimeMillis can jump when the wall clock is adjusted.
        final long startNanos = System.nanoTime();
        final W fitted = classifier_.fit(fitting);
        final long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;

        if (!(fitted instanceof MultilayerPerceptronClassificationModel)) {
            throw new ClassCastException("Expected the classifier to produce a "
                    + MultilayerPerceptronClassificationModel.class.getName() + " but it produced "
                    + (fitted == null ? "null" : fitted.getClass().getName()));
        }
        final MultilayerPerceptronClassificationModel evalModel = (MultilayerPerceptronClassificationModel) fitted;

        final EntropyResult fitResult = computeEntropy(fitting, evalModel);
        final EntropyResult testResult = computeEntropy(testing, evalModel);
        return new EvaluationResult(label_, prngSeed_, Arrays.toString(layers_), evalModel, elapsedMillis, fitResult, testResult);
    }

    /**
     * Scores {@code data} with {@code model} and aggregates row count, mean cross-entropy and mean
     * predicted-distribution entropy.
     *
     * <p>Assumes a SQL function named {@code toArrayLambda} has been registered on the session
     * elsewhere, converting the {@code probability} vector column into an array — TODO confirm.
     *
     * <p>NOTE(review): Spark SQL {@code log} returns NULL for non-positive input, so a row whose
     * observed class has probability 0 gets a NULL cross-entropy. Such rows are skipped by
     * {@code sum} but still counted by {@code count(*)}; the reported mean does not become +Infinity.
     *
     * @param data rows to score; must contain the model's input columns and {@value #OBSERVED_CLASS_COLUMN}
     * @param model fitted classification model whose output includes a {@code probability} column
     * @return timing (ms), row count, and mean entropies; the means are NaN when undefined (e.g. no rows)
     */
    private static EntropyResult computeEntropy(Dataset<Row> data, ClassificationModel<?, ?> model) {
        final long startNanos = System.nanoTime();

        Dataset<Row> scored = model.transform(data);
        scored = scored.withColumn("prob_array", functions.expr("toArrayLambda(probability)"));
        scored = scored.withColumn("distEntropy", functions.expr(distributionEntropyExpression(model.numClasses())));
        scored = scored.withColumn("crossEntropy", functions.expr("-1.0 * log(prob_array[" + OBSERVED_CLASS_COLUMN + "])"));

        // A global aggregate always yields exactly one row.
        final Row totals = scored.select(new Column[] {
                functions.expr("count(*)"),
                functions.expr("sum(crossEntropy)/count(*)"),
                functions.expr("sum(distEntropy)/count(*)")}).first();

        final long rowCount = totals.getLong(0);
        final double crossEntropy = doubleOrNaN(totals, 1);
        final double distEntropy = doubleOrNaN(totals, 2);
        final long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
        return new EntropyResult(elapsedMillis, rowCount, crossEntropy, distEntropy);
    }

    /**
     * Builds the Spark SQL expression {@code -sum_i p_i * log(p_i)} over
     * {@code prob_array[0 .. numClasses-1]}.
     *
     * <p>Each term is guarded so that {@code p_i <= 0} (or NULL) contributes 0, following the
     * convention 0 * log(0) = 0. Without the guard, Spark's {@code log} returns NULL. That would
     * null out the whole row's entropy, so the row would be dropped from the sum while still being
     * counted by {@code count(*)}.
     *
     * @param numClasses number of classes in the probability array; must be positive
     * @return a Spark SQL expression string
     * @throws IllegalArgumentException if {@code numClasses <= 0}
     */
    private static String distributionEntropyExpression(int numClasses) {
        if (numClasses <= 0) {
            throw new IllegalArgumentException("numClasses must be positive, got " + numClasses);
        }
        final StringBuilder terms = new StringBuilder();
        for (int i = 0; i < numClasses; i++) {
            final String p = "prob_array[" + i + "]";
            if (i > 0) {
                terms.append(" + ");
            }
            terms.append("IF(").append(p).append(" > 0, ")
                    .append(p).append(" * log(").append(p).append("), 0D)");
        }
        return "-1.0 * (" + terms + ")";
    }

    /**
     * Reads a double from {@code row}, mapping SQL NULL (e.g. a mean over an empty dataset) to NaN
     * instead of letting {@link Row#getDouble(int)} throw.
     */
    private static double doubleOrNaN(Row row, int index) {
        return row.isNullAt(index) ? Double.NaN : row.getDouble(index);
    }

    /** Immutable summary of one {@link ClassificationModelEvaluator#evaluate} run. */
    public static final class EvaluationResult {
        private final String _label;
        private final long _prngSeed;
        private final String _layerString;
        private final MultilayerPerceptronClassificationModel _model;
        private final long _fittingTime;
        private final EntropyResult _fittingEntropy;
        private final EntropyResult _testingEntropy;

        /**
         * @param label_ free-form label of the run
         * @param prngSeed_ seed recorded for the run
         * @param layerString_ string form of the layer sizes
         * @param model_ the fitted model
         * @param fittingTime_ time spent fitting (milliseconds when produced by {@code evaluate})
         * @param fittingEntropy_ entropy statistics on the fitting data
         * @param testingEntropy_ entropy statistics on the testing data
         */
        public EvaluationResult(String label_, long prngSeed_, String layerString_, MultilayerPerceptronClassificationModel model_, long fittingTime_, EntropyResult fittingEntropy_, EntropyResult testingEntropy_) {
            this._label = label_;
            this._prngSeed = prngSeed_;
            this._layerString = layerString_;
            this._model = model_;
            this._fittingTime = fittingTime_;
            this._fittingEntropy = fittingEntropy_;
            this._testingEntropy = testingEntropy_;
        }

        /** @return the free-form label of the run */
        public String getLabel() {
            return this._label;
        }

        /** @return the fitted model */
        public MultilayerPerceptronClassificationModel getModel() {
            return this._model;
        }

        /** @return time spent fitting the model (milliseconds when produced by {@code evaluate}) */
        public long getFittingTime() {
            return this._fittingTime;
        }

        /** @return entropy statistics computed on the fitting data */
        public EntropyResult getFittingEntropy() {
            return this._fittingEntropy;
        }

        /** @return entropy statistics computed on the testing data */
        public EntropyResult getTestingEntropy() {
            return this._testingEntropy;
        }

        /** @return the size of the fitted model's weight vector */
        public int getParamCount() {
            return this._model.weights().size();
        }

        /** @return the seed recorded for the run */
        public long getPrngSeed() {
            return this._prngSeed;
        }

        /** @return the string form of the layer sizes recorded for the run */
        public String getLayerString() {
            return this._layerString;
        }
    }

    /** Immutable entropy statistics for one scored dataset. */
    public static final class EntropyResult {
        private final long _calcTime;
        private final long _rowCount;
        private final double _crossEntropy;
        private final double _distEntropy;

        /**
         * @param calcTime_ time spent scoring and aggregating (milliseconds when produced by this class)
         * @param rowCount_ number of rows scored
         * @param crossEntropy_ mean cross-entropy against the observed class
         * @param distEntropy_ mean entropy of the predicted class distribution
         */
        public EntropyResult(long calcTime_, long rowCount_, double crossEntropy_, double distEntropy_) {
            this._calcTime = calcTime_;
            this._rowCount = rowCount_;
            this._crossEntropy = crossEntropy_;
            this._distEntropy = distEntropy_;
        }

        /** @return time spent computing these statistics */
        public long getCalcTime() {
            return this._calcTime;
        }

        /** @return number of rows scored */
        public long getRowCount() {
            return this._rowCount;
        }

        /** @return mean cross-entropy against the observed class */
        public double getCrossEntropy() {
            return this._crossEntropy;
        }

        /** @return mean entropy of the predicted class distribution */
        public double getDistEntropy() {
            return this._distEntropy;
        }
    }
}

