/*
 * Decompiled with CFR 0.152.
 */
package sklearn.preprocessing;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.DiscretizeBin;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.Interval;
import org.dmg.pmml.OpType;
import org.dmg.pmml.ParameterField;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;
import org.jpmml.python.TypeInfo;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.SkLearnTransformer;

public class KBinsDiscretizer
extends SkLearnTransformer {
    private static final String ENCODE_ONEHOT = "onehot";
    private static final String ENCODE_ONEHOT_DENSE = "onehot-dense";
    private static final String ENCODE_ORDINAL = "ordinal";

    public KBinsDiscretizer(String module, String name) {
        super(module, name);
    }

    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        TypeInfo dtype = this.getDType();
        String encode = this.getEncode();
        List<Integer> numberOfBins = this.getNumberOfBins();
        List<List<Number>> binEdges = this.getBinEdges();
        ClassDictUtil.checkSize((Collection[])new Collection[]{numberOfBins, binEdges, features});
        ArrayList<Feature> result = new ArrayList<Feature>();
        block9: for (int i = 0; i < features.size(); ++i) {
            Integer label;
            Feature feature = features.get(i);
            List<Number> bins = binEdges.get(i);
            ClassDictUtil.checkSize((int)(numberOfBins.get(i) + 1), (Collection[])new Collection[]{bins});
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            continuousFeature = KBinsDiscretizer.addEps(continuousFeature, encoder);
            ArrayList<Integer> labelCategories = new ArrayList<Integer>();
            Discretize discretize = new Discretize(continuousFeature.getName()).setDataType(dtype != null ? dtype.getDataType() : continuousFeature.getDataType());
            for (int j = 0; j < bins.size() - 1; ++j) {
                Number leftMargin = j > 0 ? (Number)bins.get(j) : (Number)null;
                Number rightMargin = j < bins.size() - 1 - 1 ? (Number)bins.get(j + 1) : (Number)null;
                Interval interval = new Interval(Interval.Closure.CLOSED_OPEN).setLeftMargin(leftMargin).setRightMargin(rightMargin);
                label = j;
                labelCategories.add(label);
                DiscretizeBin discretizeBin = new DiscretizeBin((Object)label, interval);
                discretize.addDiscretizeBins(new DiscretizeBin[]{discretizeBin});
            }
            DerivedField derivedField = encoder.createDerivedField(this.createFieldName("discretize", continuousFeature), OpType.CATEGORICAL, discretize.getDataType(), (Expression)discretize);
            switch (encode) {
                case "onehot": 
                case "onehot-dense": {
                    for (int j = 0; j < labelCategories.size(); ++j) {
                        label = (Integer)labelCategories.get(j);
                        result.add((Feature)new BinaryFeature((PMMLEncoder)encoder, (Field)derivedField, (Object)label));
                    }
                    continue block9;
                }
                case "ordinal": {
                    result.add((Feature)new IndexFeature((PMMLEncoder)encoder, (Field)derivedField, labelCategories));
                    continue block9;
                }
                default: {
                    throw new IllegalArgumentException(encode);
                }
            }
        }
        return result;
    }

    public TypeInfo getDType() {
        return this.getOptionalDType("dtype", false);
    }

    public String getEncode() {
        return (String)this.getEnum("encode", arg_0 -> ((KBinsDiscretizer)this).getString(arg_0), Arrays.asList(ENCODE_ONEHOT, ENCODE_ONEHOT_DENSE, ENCODE_ORDINAL));
    }

    public List<Integer> getNumberOfBins() {
        return this.getIntegerArray("n_bins_");
    }

    public List<List<Number>> getBinEdges() {
        List arrays = this.getArray("bin_edges_", HasArray.class);
        Function<HasArray, List<Number>> function = new Function<HasArray, List<Number>>(){

            public List<Number> apply(HasArray hasArray) {
                return hasArray.getArrayContent();
            }
        };
        return Lists.transform((List)arrays, (Function)function);
    }

    private static ContinuousFeature addEps(ContinuousFeature continuousFeature, SkLearnEncoder encoder) {
        DefineFunction defineFunction = encoder.getDefineFunction("add_eps");
        if (defineFunction == null) {
            defineFunction = KBinsDiscretizer.encodeDefineFunction("add_eps");
            encoder.addDefineFunction(defineFunction);
        }
        Apply apply = ExpressionUtil.createApply((DefineFunction)defineFunction, (Expression[])new Expression[]{continuousFeature.ref()});
        DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create((DefineFunction)defineFunction, (Object[])new Object[]{continuousFeature.getName()}), defineFunction.requireOpType(), defineFunction.requireDataType(), (Expression)apply);
        return new ContinuousFeature((PMMLEncoder)encoder, (Field)derivedField);
    }

    private static DefineFunction encodeDefineFunction(String name) {
        ParameterField valueField = new ParameterField("x");
        Double atol = 1.0E-8;
        Double rtol = 1.0E-5;
        Apply apply = ExpressionUtil.createApply((String)"+", (Expression[])new Expression[]{new FieldRef((Field)valueField), ExpressionUtil.createApply((String)"+", (Expression[])new Expression[]{ExpressionUtil.createConstant((Number)atol), ExpressionUtil.createApply((String)"*", (Expression[])new Expression[]{ExpressionUtil.createConstant((Number)rtol), ExpressionUtil.createApply((String)"abs", (Expression[])new Expression[]{new FieldRef((Field)valueField)})})})});
        DefineFunction defineFunction = new DefineFunction(name, OpType.CONTINUOUS, DataType.DOUBLE, null, (Expression)apply).addParameterFields(new ParameterField[]{valueField});
        return defineFunction;
    }
}

