/*
 * Decompiled with CFR 0.152.
 */
package optbinning;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import optbinning.BinnedFeature;
import optbinning.OptimalBinningUtil;
import org.dmg.pmml.Apply;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.DiscretizeBin;
import org.dmg.pmml.Expression;
import org.dmg.pmml.HasMapMissingTo;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;

/**
 * Converter for a fitted optimal binning transformer.
 *
 * <p>Encodes exactly one input feature as a categorical {@code double} derived field. The field's value is
 * the metric of the bin the input falls into: weight of evidence or event rate, computed from the
 * {@code _n_event} / {@code _n_nonevent} attributes. Inputs equal to one of the {@code special_codes} map to
 * {@link #CATEGORY_SPECIAL}; missing inputs map to {@link #CATEGORY_MISSING}.
 */
public class OptimalBinning extends Transformer {
    /**
     * Output value for missing input; also used as {@code mapMissingTo} of the binning expressions.
     * NOTE(review): presumably mirrors optbinning's default {@code metric_missing} of 0 - confirm.
     */
    public static final Double CATEGORY_MISSING = 0.0;
    /**
     * Output value for input matching one of the special codes.
     * NOTE(review): presumably mirrors optbinning's default {@code metric_special} of 0 - confirm.
     */
    public static final Double CATEGORY_SPECIAL = 0.0;
    private static final String DTYPE_CATEGORICAL = "categorical";
    private static final String DTYPE_NUMERICAL = "numerical";
    private static final String METRIC_EVENT_RATE = "event_rate";
    private static final String METRIC_WOE = "woe";

    public OptimalBinning(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes the single input feature as a binned derived field.
     *
     * <p>The returned {@link BinnedFeature} carries one predicate per bin, in this order:
     * <ol>
     *   <li>one per regular bin;</li>
     *   <li>the special-codes bin, or {@code null} when no special codes are configured;</li>
     *   <li>the missing-values bin.</li>
     * </ol>
     * This is the same layout as the metric vector returned by {@link #getCategoriesOut()}.
     *
     * @param features exactly one input feature
     * @param encoder the encoder that receives the new derived field
     * @return a singleton list holding the binned feature
     * @throws IllegalArgumentException if the {@code dtype} attribute is not supported
     */
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        String dtype = this.getDType();
        List<Number> specialCodes = this.getSpecialCodes();
        List<Number> splits = this.getSplitsOptimal();

        // One metric value per regular bin (splits + 1), followed by the special bin and the missing bin.
        List<Double> categoriesOut = this.getCategoriesOut();
        ClassDictUtil.checkSize(splits.size() + 3, categoriesOut);

        // The special and missing bins are served by the CATEGORY_* constants below,
        // so their computed metrics (NaN for an empty bin) are dropped here.
        categoriesOut = categoriesOut.subList(0, categoriesOut.size() - 2);

        SchemaUtil.checkSize(1, features);
        Feature feature = features.get(0);

        PredicateManager predicateManager = new PredicateManager();
        List<Predicate> predicates = new ArrayList<>();

        Expression expression;
        if (!splits.isEmpty()) {
            OptimalBinningUtil.checkIncreasingOrder(splits);
            switch (dtype) {
                case DTYPE_NUMERICAL:
                    expression = this.encodeNumericalBinning(feature, splits, categoriesOut, predicateManager, predicates);
                    break;
                case DTYPE_CATEGORICAL: {
                    List<Object> categoriesIn = this.getCategoriesIn();
                    if (feature instanceof WildcardFeature) {
                        // Align the input field's data type with the fitted categories,
                        // then convert it to a categorical feature over those categories.
                        WildcardFeature wildcardFeature = (WildcardFeature)feature;
                        DataType dataType = TypeUtil.getDataType(categoriesIn, DataType.STRING);
                        DataField dataField = wildcardFeature.getField();
                        if (dataField.requireDataType() != dataType) {
                            dataField.setDataType(dataType);
                        }
                        feature = wildcardFeature.toCategoricalFeature(categoriesIn);
                    }
                    expression = this.encodeCategoricalBinning(feature, splits, categoriesIn, categoriesOut, predicateManager, predicates);
                    break;
                }
                default:
                    throw new IllegalArgumentException(dtype);
            }
        } else {
            // Without splits there is a single bin that covers every non-missing value.
            expression = ExpressionUtil.createApply("if",
                ExpressionUtil.createApply("isNotMissing", feature.ref()),
                ExpressionUtil.createConstant(null, (Object)categoriesOut.get(0)),
                ExpressionUtil.createConstant(CATEGORY_MISSING));
            // Keep the predicate list aligned with the bin layout: one entry for the single regular bin.
            predicates.add(predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_NOT_MISSING, null));
        }

        List<Object> categories = categoriesOut.stream().distinct().collect(Collectors.toList());

        if (!specialCodes.isEmpty()) {
            Apply isSpecialApply = ExpressionUtil.createValueApply((Expression)feature.ref(), specialCodes);
            // Copy the wrapped expression's missing-value mapping onto the special-code check.
            if (expression instanceof HasMapMissingTo) {
                HasMapMissingTo<?> hasMapMissingTo = (HasMapMissingTo<?>)expression;
                isSpecialApply.setMapMissingTo(hasMapMissingTo.getMapMissingTo());
            }
            // Special codes take precedence over the regular bins.
            expression = ExpressionUtil.createApply("if", isSpecialApply, ExpressionUtil.createConstant(CATEGORY_SPECIAL), expression);
            predicates.add(predicateManager.createPredicate(feature, specialCodes));
            categories = OptimalBinningUtil.ensureCategory(categories, CATEGORY_SPECIAL);
        } else {
            // Placeholder keeps the missing-bin predicate at a fixed position.
            predicates.add(null);
        }

        predicates.add(predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null));
        categories = OptimalBinningUtil.ensureCategory(categories, CATEGORY_MISSING);

        DerivedField derivedField = encoder.createDerivedField(this.createFieldName("optBinning", feature), OpType.CATEGORICAL, DataType.DOUBLE, expression);
        Feature binnedFeature = new BinnedFeature(encoder, derivedField, categories, predicates);
        return Collections.singletonList(binnedFeature);
    }

    /**
     * Builds a {@link Discretize} with one closed-open interval per bin and appends the matching bin predicates.
     *
     * <p>Bin {@code i} is {@code [splits[i - 1], splits[i])}. The first bin has no lower bound and the last
     * bin has no upper bound. Missing input maps to {@link #CATEGORY_MISSING}.
     *
     * @param feature the input feature (converted to a continuous feature)
     * @param splits strictly increasing cut points
     * @param categoriesOut metric value per bin; must hold at least {@code splits.size() + 1} entries
     * @param predicateManager shared predicate factory
     * @param predicates output list; receives one predicate per bin
     * @return the discretization expression
     */
    private Discretize encodeNumericalBinning(Feature feature, List<Number> splits, List<Double> categoriesOut, PredicateManager predicateManager, List<Predicate> predicates) {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        Discretize discretize = new Discretize(continuousFeature.getName()).setMapMissingTo(CATEGORY_MISSING);
        for (int i = 0; i <= splits.size(); i++) {
            Number leftMargin = (i > 0) ? splits.get(i - 1) : null;
            Number rightMargin = (i < splits.size()) ? splits.get(i) : null;

            Interval interval = new Interval(Interval.Closure.CLOSED_OPEN, leftMargin, rightMargin);
            discretize.addDiscretizeBins(new DiscretizeBin(categoriesOut.get(i), interval));

            Predicate leftPredicate = (leftMargin != null)
                ? predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, leftMargin)
                : null;
            Predicate rightPredicate = (rightMargin != null)
                ? predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, rightMargin)
                : null;

            Predicate predicate;
            if (leftPredicate != null && rightPredicate != null) {
                predicate = predicateManager.createCompoundPredicate(CompoundPredicate.BooleanOperator.AND, leftPredicate, rightPredicate);
            } else if (leftPredicate != null) {
                predicate = leftPredicate;
            } else {
                predicate = rightPredicate;
            }
            predicates.add(predicate);
        }
        return discretize;
    }

    /**
     * Builds a {@link MapValues} lookup from each input category to its bin's metric value,
     * and appends one set-membership predicate per bin.
     *
     * <p>Each split value is treated as a (possibly fractional) position in {@code categoriesIn}. It is
     * rounded up to give the exclusive end index of the current bin. The last bin runs to the end of the list.
     * Missing input maps to {@link #CATEGORY_MISSING}.
     *
     * @param feature the categorical input feature
     * @param splits increasing split positions into {@code categoriesIn}
     * @param categoriesIn ordered input categories
     * @param categoriesOut metric value per bin; must hold at least {@code splits.size() + 1} entries
     * @param predicateManager shared predicate factory
     * @param predicates output list; receives one predicate per bin
     * @return the lookup expression
     */
    private MapValues encodeCategoricalBinning(Feature feature, List<Number> splits, List<?> categoriesIn, List<Double> categoriesOut, PredicateManager predicateManager, List<Predicate> predicates) {
        List<Object> inputValues = new ArrayList<>();
        List<Double> outputValues = new ArrayList<>();
        int begin = 0;
        for (int i = 0; i <= splits.size(); i++) {
            int end = (i < splits.size())
                ? (int)Math.ceil(splits.get(i).doubleValue())
                : categoriesIn.size();

            List<?> binCategoriesIn = categoriesIn.subList(begin, end);
            Double binCategoryOut = categoriesOut.get(i);
            for (Object categoryIn : binCategoriesIn) {
                inputValues.add(categoryIn);
                outputValues.add(binCategoryOut);
            }
            predicates.add(predicateManager.createPredicate(feature, binCategoriesIn));
            begin = end;
        }
        return ExpressionUtil.createMapValues(feature.getName(), inputValues, outputValues).setMapMissingTo(CATEGORY_MISSING);
    }

    /** Returns the fitted input categories (the {@code _categories} attribute). */
    public List<Object> getCategoriesIn() {
        return this.getObjectArray("_categories");
    }

    /**
     * Computes the per-bin output value from the event / non-event counts.
     *
     * <p>With metric {@code "event_rate"} the value is {@code events / (events + nonEvents)}. With metric
     * {@code "woe"} it is {@code log(nonEvents / events) + log(totalEvents / totalNonEvents)}. A bin with no
     * records yields {@code NaN}.
     *
     * @return one value per entry of {@code _n_event}
     * @throws IllegalArgumentException if the metric is not recognized
     */
    public List<Double> getCategoriesOut() {
        String metric = this.getMetric();
        List<Integer> numberOfEvents = this.getNumberOfEvents();
        List<Integer> numberOfNonEvents = this.getNumberOfNonEvents();
        ClassDictUtil.checkSize(numberOfEvents, numberOfNonEvents);

        // Log of the overall event / non-event odds; this shifts per-bin log odds into WoE.
        double constant = Math.log((double)OptimalBinningUtil.sumExact(numberOfEvents) / (double)OptimalBinningUtil.sumExact(numberOfNonEvents));

        List<Double> result = new ArrayList<>(numberOfEvents.size());
        for (int i = 0; i < numberOfEvents.size(); i++) {
            int events = numberOfEvents.get(i);
            int nonEvents = numberOfNonEvents.get(i);
            double eventRate = (double)events / (double)Math.addExact(events, nonEvents);
            switch (metric) {
                case METRIC_EVENT_RATE:
                    result.add(eventRate);
                    break;
                case METRIC_WOE:
                    // 1 / eventRate - 1 == nonEvents / events
                    result.add(Math.log(1.0 / eventRate - 1.0) + constant);
                    break;
                default:
                    throw new IllegalArgumentException(metric);
            }
        }
        return result;
    }

    /** Returns the {@code dtype} attribute, validated to be {@code "categorical"} or {@code "numerical"}. */
    public String getDType() {
        return this.getEnum("dtype", this::getString, Arrays.asList(DTYPE_CATEGORICAL, DTYPE_NUMERICAL));
    }

    /** Metric used when the {@code metric} attribute is absent; subclasses may override. */
    public String getDefaultMetric() {
        return METRIC_WOE;
    }

    /** Metrics accepted by {@link #getMetric()}; subclasses may override. */
    public Collection<String> getSupportedMetrics() {
        return Arrays.asList(METRIC_EVENT_RATE, METRIC_WOE);
    }

    /** Returns the {@code metric} attribute, or {@link #getDefaultMetric()} when it is absent. */
    public String getMetric() {
        if (!this.hasattr("metric")) {
            return this.getDefaultMetric();
        }
        return this.getEnum("metric", this::getString, this.getSupportedMetrics());
    }

    /** Sets the {@code metric} attribute and returns this instance for chaining. */
    public OptimalBinning setMetric(String metric) {
        this.setattr("metric", metric);
        return this;
    }

    /** Returns the per-bin event counts (the {@code _n_event} attribute). */
    public List<Integer> getNumberOfEvents() {
        return this.getIntegerArray("_n_event");
    }

    /** Returns the per-bin non-event counts (the {@code _n_nonevent} attribute). */
    public List<Integer> getNumberOfNonEvents() {
        return this.getIntegerArray("_n_nonevent");
    }

    /** Returns the {@code special_codes} attribute, or an empty list when it is unset. */
    public List<Number> getSpecialCodes() {
        if (this.getOptionalObject("special_codes") == null) {
            return Collections.emptyList();
        }
        return this.getListLike("special_codes", Number.class);
    }

    /** Returns the optimal split points (the {@code _splits_optimal} attribute). */
    public List<Number> getSplitsOptimal() {
        return this.getNumberArray("_splits_optimal");
    }
}

