/*
 * Decompiled with CFR 0.152.
 */
package org.riversun.ml.spark;

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.riversun.ml.spark.Importance;

/**
 * Pairs the feature importances reported by a Spark ML tree-based model with the feature
 * names stored in the ML attribute metadata ({@code ml_attr.attrs}) of the model's
 * features column.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class FeatureImportance {
    private final PredictionModel<Vector, ?> model;
    private final StructType schema;
    private final Order sort;

    private FeatureImportance(Builder builder) {
        this.model = builder.model;
        this.schema = builder.schema;
        this.sort = builder.sort;
    }

    /**
     * Builds one {@link Importance} per feature of the model.
     *
     * <p>Every entry's {@code rank} is its 0-based position when sorted by descending score,
     * whatever output {@link Order} is configured. Features whose index has no entry in the
     * column metadata get a {@code null} name.
     *
     * @return the importances, ordered as configured via {@link Builder#sort(Order)}
     * @throws IllegalStateException if the model type does not expose feature importances
     * @throws NoSuchElementException if the features column, or its {@code ml_attr.attrs}
     *     metadata, is missing from the schema
     */
    public List<Importance> getResult() {
        Vector featureImportances = extractFeatureImportances(this.model);
        return this.zipImportances(featureImportances, this.model.getFeaturesCol(), this.schema);
    }

    /** Returns {@code featureImportances()} of the supported tree-based model types. */
    private static Vector extractFeatureImportances(PredictionModel<Vector, ?> model) {
        if (model instanceof GBTRegressionModel) {
            return ((GBTRegressionModel) model).featureImportances();
        }
        if (model instanceof GBTClassificationModel) {
            return ((GBTClassificationModel) model).featureImportances();
        }
        if (model instanceof RandomForestRegressionModel) {
            return ((RandomForestRegressionModel) model).featureImportances();
        }
        if (model instanceof RandomForestClassificationModel) {
            return ((RandomForestClassificationModel) model).featureImportances();
        }
        if (model instanceof DecisionTreeRegressionModel) {
            return ((DecisionTreeRegressionModel) model).featureImportances();
        }
        if (model instanceof DecisionTreeClassificationModel) {
            return ((DecisionTreeClassificationModel) model).featureImportances();
        }
        throw new IllegalStateException(model + " doesn't have feature importances. "
                + "You should specify an instance of "
                + "GBTRegressionModel,GBTClassificationModel,"
                + "RandomForestRegressionModel,RandomForestClassificationModel,"
                + "DecisionTreeRegressionModel,DecisionTreeClassificationModel");
    }

    /** Combines raw scores with feature names, assigns ranks, and orders the result. */
    private List<Importance> zipImportances(Vector featureImportances, String featuresCol, StructType schema) {
        Map<Integer, String> idNameMap = featureNamesByIndex(featuresCol, schema);
        double[] importanceScores = featureImportances.toArray();

        List<Importance> rawImportanceList = IntStream.range(0, importanceScores.length)
                .mapToObj(idx -> new Importance(idx, idNameMap.get(idx), importanceScores[idx], 0))
                .collect(Collectors.toList());

        // Sorting is stable, so equal scores keep their feature-index order.
        List<Importance> descSortedImportanceList = rawImportanceList.stream()
                .sorted(Comparator.comparingDouble((Importance imp) -> imp.score).reversed())
                .collect(Collectors.toList());

        // The lists share the same Importance objects, so every ordering sees these ranks.
        for (int i = 0; i < descSortedImportanceList.size(); ++i) {
            descSortedImportanceList.get(i).rank = i;
        }

        switch (this.sort) {
            case ASCENDING:
                return descSortedImportanceList.stream()
                        .sorted(Comparator.comparingDouble((Importance imp) -> imp.score))
                        .collect(Collectors.toList());
            case DESCENDING:
                return descSortedImportanceList;
            default:
                return rawImportanceList;
        }
    }

    /**
     * Reads {@code ml_attr.attrs} of the features column and maps each attribute's
     * {@code idx} to its {@code name}.
     *
     * <p>The "nominal", "numeric" and "binary" groups are merged in that order. A later
     * entry with the same index overwrites an earlier one.
     */
    private static Map<Integer, String> featureNamesByIndex(String featuresCol, StructType schema) {
        scala.Option<Object> fieldIndex = schema.getFieldIndex(featuresCol);
        if (fieldIndex.isEmpty()) {
            throw new NoSuchElementException(
                    "Features column '" + featuresCol + "' not found in schema " + schema.simpleString());
        }
        StructField featuresField = schema.fields()[(Integer) fieldIndex.get()];
        Metadata metadata = featuresField.metadata();
        if (!metadata.contains("ml_attr") || !metadata.getMetadata("ml_attr").contains("attrs")) {
            throw new NoSuchElementException(
                    "Features column '" + featuresCol + "' has no 'ml_attr.attrs' metadata");
        }
        Metadata featuresFieldAttrs = metadata.getMetadata("ml_attr").getMetadata("attrs");

        Map<Integer, String> idNameMap = new HashMap<>();
        for (String fieldKey : new String[] {"nominal", "numeric", "binary"}) {
            if (!featuresFieldAttrs.contains(fieldKey)) {
                continue;
            }
            for (Metadata attr : featuresFieldAttrs.getMetadataArray(fieldKey)) {
                idNameMap.put((int) attr.getLong("idx"), attr.getString("name"));
            }
        }
        return idNameMap;
    }

    /** Fluent builder for {@link FeatureImportance}. Defaults to {@link Order#DESCENDING}. */
    public static class Builder {
        private final PredictionModel<Vector, ?> model;
        private final StructType schema;
        private Order sort = Order.DESCENDING;

        /**
         * @param model a fitted tree-based prediction model
         * @param schema a schema containing the model's features column and its ML metadata
         */
        public Builder(PredictionModel<Vector, ?> model, StructType schema) {
            this.model = model;
            this.schema = schema;
        }

        /**
         * Sets the output order of {@link FeatureImportance#getResult()}.
         *
         * @throws NullPointerException if {@code sort} is null
         */
        public Builder sort(Order sort) {
            this.sort = Objects.requireNonNull(sort, "sort");
            return this;
        }

        /**
         * Creates the {@link FeatureImportance}.
         *
         * @throws NullPointerException if the model or the schema is null
         */
        public FeatureImportance build() {
            Objects.requireNonNull(this.model, "model");
            Objects.requireNonNull(this.schema, "schema");
            return new FeatureImportance(this);
        }
    }

    /** Output order of the importance list. */
    public static enum Order {
        /** Lowest score first. */
        ASCENDING,
        /** Highest score first. */
        DESCENDING,
        /** Feature-index order, as produced by the model. */
        UNSORTED;

    }
}

