/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.hist_gradient_boosting;

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.TypeInfo;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.SkLearnException;
import sklearn.Transformer;
import sklearn.compose.ColumnTransformer;
import sklearn.ensemble.hist_gradient_boosting.BinMapper;
import sklearn.ensemble.hist_gradient_boosting.TreePredictor;
import sklearn.ensemble.hist_gradient_boosting.TreePredictorUtil;
import sklearn.preprocessing.OrdinalEncoder;

public class HistGradientBoostingUtil {
    private HistGradientBoostingUtil() {
    }

    public static Schema preprocess(final ColumnTransformer preprocessor, Schema schema) {
        final SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
        Label label = schema.getLabel();
        final ArrayList<Feature> features = new ArrayList<Feature>(schema.getFeatures());
        ColumnTransformer filterPreprocessor = new ColumnTransformer(preprocessor.getPythonModule(), preprocessor.getPythonName()){
            {
                super(module, name);
                this.update((Map)((Object)preprocessor));
            }

            @Override
            public List<Object[]> getFittedTransformers() {
                final List<Object[]> fittedTransformers = super.getFittedTransformers();
                List names = fittedTransformers.stream().map(fittedTransformer -> ColumnTransformer.getName(fittedTransformer)).collect(Collectors.toList());
                if (!Objects.equals(names, Arrays.asList("encoder", "numerical"))) {
                    if (!Objects.equals(names, Arrays.asList("encoder", "numerical", "remainder"))) {
                        throw new SkLearnException("Expected [encoder, numerical, remainder] as transformer names, got " + names + " as transformer names");
                    }
                }
                AbstractList<Object[]> result = new AbstractList<Object[]>(){

                    @Override
                    public int size() {
                        return fittedTransformers.size();
                    }

                    @Override
                    public Object[] get(int index) {
                        Object[] fittedTransformer = (Object[])fittedTransformers.get(index);
                        Transformer transformer = 1.getTransformer(fittedTransformer);
                        if (transformer instanceof OrdinalEncoder) {
                            final OrdinalEncoder ordinalEncoder = (OrdinalEncoder)transformer;
                            final List rowFeatures = 1.getFeatures(fittedTransformer, features, encoder);
                            OrdinalEncoder filterOrdinalEncoder = new OrdinalEncoder(ordinalEncoder.getPythonModule(), ordinalEncoder.getPythonName()){
                                {
                                    super(module, name);
                                    this.update((Map)((Object)ordinalEncoder));
                                }

                                @Override
                                public List<List<Object>> getCategories() {
                                    final List<List<Object>> categories = super.getCategories();
                                    ClassDictUtil.checkSize((Collection[])new Collection[]{categories, rowFeatures});
                                    AbstractList<List<Object>> result = new AbstractList<List<Object>>(){

                                        @Override
                                        public int size() {
                                            return categories.size();
                                        }

                                        @Override
                                        public List<Object> get(int index) {
                                            Feature rowFeature = (Feature)rowFeatures.get(index);
                                            if (rowFeature instanceof CategoricalFeature) {
                                                CategoricalFeature categoricalFeature = (CategoricalFeature)rowFeature;
                                                return categoricalFeature.getValues();
                                            }
                                            return (List)categories.get(index);
                                        }
                                    };
                                    return result;
                                }

                                @Override
                                public TypeInfo getDType() {
                                    TypeInfo result = new TypeInfo(){

                                        public DataType getDataType() {
                                            return DataType.INTEGER;
                                        }
                                    };
                                    return result;
                                }

                                @Override
                                public String createFieldName(String function, List<?> args) {
                                    return super.createFieldName("hist_" + function, args);
                                }
                            };
                            1.setTransformer(fittedTransformer, filterOrdinalEncoder);
                        }
                        return fittedTransformer;
                    }
                };
                return result;
            }
        };
        List<Feature> filterFeatures = filterPreprocessor.encode(features, encoder);
        return new Schema((ModelEncoder)encoder, label, filterFeatures);
    }

    public static MiningModel encodeHistGradientBoosting(List<List<TreePredictor>> predictors, BinMapper binMapper, List<? extends Number> baselinePredictions, int column, Schema schema) {
        List<TreePredictor> treePredictors = predictors.stream().map(predictor -> (TreePredictor)((Object)((Object)predictor.get(column)))).collect(Collectors.toList());
        Number baselinePrediction = baselinePredictions.get(column);
        return HistGradientBoostingUtil.encodeHistGradientBoosting(treePredictors, binMapper, baselinePrediction, schema);
    }

    public static MiningModel encodeHistGradientBoosting(List<TreePredictor> treePredictors, BinMapper binMapper, Number baselinePrediction, Schema schema) {
        ContinuousLabel continuousLabel = (ContinuousLabel)schema.getLabel();
        PredicateManager predicateManager = new PredicateManager();
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        ArrayList<TreeModel> treeModels = new ArrayList<TreeModel>();
        for (TreePredictor treePredictor : treePredictors) {
            TreeModel treeModel = TreePredictorUtil.encodeTreeModel(treePredictor, binMapper, predicateManager, segmentSchema);
            treeModels.add(treeModel);
        }
        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)continuousLabel)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, (Segmentation.MissingPredictionTreatment)Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels)).setTargets(ModelUtil.createRescaleTargets(null, (Number)baselinePrediction, (ContinuousLabel)continuousLabel));
        return miningModel;
    }
}

