/*
 * Decompiled with CFR 0.152.
 */
package interpret.glassbox.ebm;

import interpret.glassbox.ebm.HasExplainableBooster;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.DiscretizeBin;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldColumnPair;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.OpType;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.XMLUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;
import sklearn.Estimator;

/**
 * Converts the binning and term-score attributes of an estimator implementing {@link HasExplainableBooster}
 * into PMML features.
 *
 * <p>Each input feature is first binned once per bin level: continuous features through a {@link Discretize}
 * transformation, nominal features through a categorical field. Each model term is then encoded as a
 * {@link MapValues} lookup table from bin value(s) to the term's score.
 */
public class ExplainableBoostingUtil {

    private ExplainableBoostingUtil() {
    }

    /**
     * Encodes every term of the booster as a PMML feature.
     *
     * @param estimator the estimator; it must implement {@link HasExplainableBooster}
     * @param schema the schema whose i-th feature corresponds to the booster's i-th bins entry
     * @param <E> the estimator type
     * @return one lookup feature per term, in term order
     * @throws IllegalArgumentException if the booster attributes are inconsistent with each other or with the schema
     */
    public static <E extends Estimator> List<Feature> encodeExplainableBooster(E estimator, Schema schema) {
        HasExplainableBooster booster = (HasExplainableBooster)estimator;

        List<List<?>> bins = booster.getBins();
        List<String> featureTypesIn = booster.getFeatureTypesIn();
        List<Object[]> termFeatures = booster.getTermFeatures();
        List<HasArray> termScores = booster.getTermScores();

        ClassDictUtil.checkSize(bins, featureTypesIn);
        ClassDictUtil.checkSize(termFeatures, termScores);

        ModelEncoder encoder = schema.getEncoder();
        List<? extends Feature> features = schema.getFeatures();

        // Schema features are matched to bins entries by position; fail clearly instead of with an index error
        if (features.size() < bins.size()) {
            throw new IllegalArgumentException("Expected at least " + bins.size() + " features, got " + features.size());
        }

        List<List<CategoricalFeature>> binLevelFeatures = new ArrayList<>(bins.size());
        for (int i = 0; i < bins.size(); i++) {
            binLevelFeatures.add(encodeBinLevelFeatures(features.get(i), bins.get(i), featureTypesIn.get(i), encoder));
        }

        List<Feature> result = new ArrayList<>(termFeatures.size());
        for (int i = 0; i < termFeatures.size(); i++) {
            result.add(encodeLookupFeature(termFeatures.get(i), termScores.get(i), binLevelFeatures, encoder));
        }
        return result;
    }

    /**
     * Bins one input feature at every bin level listed for it.
     *
     * @param feature the input feature
     * @param binLevels one entry per bin level: a {@link HasArray} of cut points for "continuous" features,
     *     or a category-to-bin-index {@link Map} for "nominal" features
     * @param featureType either "continuous" or "nominal"
     * @param encoder the encoder that receives the derived/categorical fields
     * @return one binned feature per bin level, in level order
     * @throws IllegalArgumentException if the feature type is not supported
     */
    private static List<CategoricalFeature> encodeBinLevelFeatures(Feature feature, List<?> binLevels, String featureType, ModelEncoder encoder) {
        List<CategoricalFeature> result = new ArrayList<>(binLevels.size());

        // Derived field names only need a level suffix when there is more than one level to tell apart
        boolean multiLevel = binLevels.size() > 1;

        for (int i = 0; i < binLevels.size(); i++) {
            Object binLevel = binLevels.get(i);

            switch (featureType) {
                case "continuous":
                    result.add(binContinuous(feature, (HasArray)binLevel, multiLevel ? Integer.valueOf(i) : null, encoder));
                    break;
                case "nominal":
                    result.add(binNominal(feature, (Map<?, ?>)binLevel, encoder));
                    break;
                default:
                    throw new IllegalArgumentException("Feature type " + featureType + " is not supported");
            }
        }
        return result;
    }

    /**
     * Discretizes a continuous feature by the given cut points.
     *
     * <p>n cut points c0..c(n-1) yield n + 1 closed-open intervals labelled 0..n:
     * (-inf, c0), [c0, c1), ..., [c(n-1), +inf). The cut points are presumably ascending; this is not verified here.
     *
     * @param feature the input feature (converted to a continuous feature)
     * @param binLevel the cut points
     * @param binLevelIndex the bin level, used to disambiguate the derived field name; {@code null} if there is only one level
     * @param encoder the encoder that receives the derived field
     * @return an index feature over the interval labels 0..n
     * @throws IllegalArgumentException if there are no cut points
     */
    private static IndexFeature binContinuous(Feature feature, HasArray binLevel, Integer binLevelIndex, ModelEncoder encoder) {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        List<?> cuts = binLevel.getArrayContent();
        if (cuts.isEmpty()) {
            throw new IllegalArgumentException("Continuous feature " + continuousFeature.getName() + " has no bin cut points");
        }

        Discretize discretize = new Discretize(continuousFeature.getName());

        List<Integer> labelCategories = new ArrayList<>(cuts.size() + 1);
        for (int j = 0; j <= cuts.size(); j++) {
            // A null margin leaves that side of the interval unbounded
            Number leftMargin = (j > 0) ? (Number)cuts.get(j - 1) : null;
            Number rightMargin = (j < cuts.size()) ? (Number)cuts.get(j) : null;

            Integer label = j;
            labelCategories.add(label);

            Interval interval = new Interval(Interval.Closure.CLOSED_OPEN)
                .setLeftMargin(leftMargin)
                .setRightMargin(rightMargin);

            discretize.addDiscretizeBins(new DiscretizeBin[]{new DiscretizeBin(label, interval)});
        }

        String name = (binLevelIndex != null)
            ? FieldNameUtil.create("bin", new Object[]{continuousFeature, binLevelIndex})
            : FieldNameUtil.create("bin", new Object[]{continuousFeature});

        DerivedField derivedField = encoder.createDerivedField(name, OpType.CATEGORICAL, DataType.INTEGER, discretize);

        return new IndexFeature(encoder, derivedField, labelCategories);
    }

    /**
     * Encodes a nominal feature as a categorical field whose categories are ordered by bin index.
     *
     * @param feature the input feature
     * @param binLevel a mapping from category to numeric bin index
     * @param encoder the encoder that owns the categorical field
     * @return a categorical feature whose j-th category is the one assigned bin index j + 1
     * @throws IllegalArgumentException if the bin indices are not exactly 1, 2, ..., n
     */
    private static CategoricalFeature binNominal(Feature feature, Map<?, ?> binLevel, PMMLEncoder encoder) {
        List<Map.Entry<?, ?>> entries = new ArrayList<>(binLevel.entrySet());

        // Compare numerically rather than through Comparable, so that mixed Number subtypes
        // (e.g. Integer and Long) cannot raise a ClassCastException
        entries.sort((left, right) -> Integer.compare(((Number)left.getValue()).intValue(), ((Number)right.getValue()).intValue()));

        List<Object> categories = new ArrayList<>(entries.size());
        for (int j = 0; j < entries.size(); j++) {
            Map.Entry<?, ?> entry = entries.get(j);

            Object category = ScalarUtil.decode(entry.getKey());
            Number bin = (Number)entry.getValue();

            // The j-th category is later aligned with term score slot j + 1, so indices must be contiguous from 1
            if (bin.intValue() != j + 1) {
                throw new IllegalArgumentException("Expected bin index " + (j + 1) + " for category " + category + ", got " + bin);
            }

            categories.add(category);
        }

        Field<?> field = encoder.toCategorical(feature.getName(), categories);
        field.setDataType(TypeUtil.getDataType(categories, DataType.STRING));

        return new CategoricalFeature(encoder, field, categories);
    }

    /**
     * Encodes one term (a one-feature or two-feature term) as a {@link MapValues} lookup from bin value(s)
     * to term score.
     *
     * <p>The first and the last slot along every term score axis are not mapped to any bin. The remaining
     * slots line up, in order, with the values of the corresponding binned feature. For two-feature terms
     * the first feature indexes rows and the second indexes columns of a row-major score matrix.
     *
     * @param termFeature the indices of the features that make up the term (length 1 or 2)
     * @param termScore the term scores; 1-D for one-feature terms, 2-D for two-feature terms
     * @param binLevelFeatures the binned features per input feature, per bin level
     * @param encoder the encoder that receives the derived field
     * @return a feature over the looked-up term score
     * @throws IllegalArgumentException if the term has an unsupported arity, or bin counts disagree with the score shape
     */
    private static Feature encodeLookupFeature(Object[] termFeature, HasArray termScore, List<List<CategoricalFeature>> binLevelFeatures, PMMLEncoder encoder) {
        int[] termScoreShape = termScore.getArrayShape();
        List<?> termScoreContent = termScore.getArrayContent();

        int rows;
        int columns;
        switch (termFeature.length) {
            case 1:
                rows = termScoreShape[0] - 2;
                columns = 1;
                break;
            case 2:
                rows = termScoreShape[0] - 2;
                columns = termScoreShape[1] - 2;
                break;
            default:
                throw new IllegalArgumentException("Expected a term over 1 or 2 features, got " + termFeature.length + " features");
        }

        String outputColumn = "data:output";

        // NOTE(review): any missing input maps the term to 0.0, so term score slot 0 is never used;
        // confirm this matches how the booster scores missing values
        MapValues mapValues = new MapValues()
            .setMapMissingTo(0.0)
            .setOutputColumn(outputColumn);

        List<CategoricalFeature> inputFeatures = new ArrayList<>(termFeature.length);
        Map<String, List<?>> data = new LinkedHashMap<>();

        for (int j = 0; j < termFeature.length; j++) {
            int featureIndex = ((Number)termFeature[j]).intValue();
            List<CategoricalFeature> binnedFeatures = binLevelFeatures.get(featureIndex);

            CategoricalFeature binnedFeature;
            String inputColumn;
            List<?> categoryValues;

            if (termFeature.length == 1) {
                // One-feature term: use the first bin level
                binnedFeature = binnedFeatures.get(0);
                inputColumn = "data:input";
                checkBinCount(binnedFeature, rows);

                categoryValues = binnedFeature.getValues();
            } else {
                // Two-feature term: use the second bin level when present, otherwise the only one
                binnedFeature = binnedFeatures.get(Math.min(binnedFeatures.size() - 1, 1));
                inputColumn = "data:" + XMLUtil.createTagName("input_" + j);
                checkBinCount(binnedFeature, (j == 0) ? rows : columns);

                // Enumerate the cartesian product in row-major order
                List<Object> pairValues = new ArrayList<>(rows * columns);
                for (int k = 0; k < rows * columns; k++) {
                    int index = (j == 0) ? (k / columns) : (k % columns);
                    pairValues.add(binnedFeature.getValue(index));
                }
                categoryValues = pairValues;
            }

            inputFeatures.add(binnedFeature);

            mapValues.addFieldColumnPairs(new FieldColumnPair[]{new FieldColumnPair(binnedFeature.getName(), inputColumn)});
            data.put(inputColumn, categoryValues);
        }

        List<?> outputValues;
        if (termFeature.length == 1) {
            outputValues = termScoreContent.subList(1, termScoreContent.size() - 1);
        } else {
            // Skip the first and last slot of each axis; the full row stride is columns + 2
            int stride = columns + 2;

            List<Number> pairScores = new ArrayList<>(rows * columns);
            for (int row = 0; row < rows; row++) {
                for (int column = 0; column < columns; column++) {
                    pairScores.add((Number)termScoreContent.get((row + 1) * stride + (column + 1)));
                }
            }
            outputValues = pairScores;
        }
        data.put(outputColumn, outputValues);

        InlineTable inlineTable = PMMLUtil.createInlineTable(data);
        mapValues.setInlineTable(inlineTable);

        String name = FieldNameUtil.create("lookup", inputFeatures);

        // NOTE(review): the lookup field is declared CATEGORICAL although it carries DOUBLE scores and is
        // wrapped as a ContinuousFeature below; confirm this optype is intended
        DerivedField derivedField = encoder.createDerivedField(name, OpType.CATEGORICAL, DataType.DOUBLE, mapValues);

        return new ContinuousFeature(encoder, derivedField);
    }

    /**
     * Verifies that a binned feature has exactly as many values as the term score axis it is aligned with.
     *
     * @throws IllegalArgumentException if the counts differ
     */
    private static void checkBinCount(CategoricalFeature binnedFeature, int expectedCount) {
        int actualCount = binnedFeature.getValues().size();
        if (actualCount != expectedCount) {
            throw new IllegalArgumentException("Expected " + expectedCount + " bins for feature " + binnedFeature.getName() + ", got " + actualCount);
        }
    }
}

