/*
 * Decompiled with CFR 0.152.
 */
package tpot.builtins;

import java.util.ArrayList;
import java.util.List;
import net.razorvine.pickle.objects.ClassDict;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.Transformer;
import sklearn.TypeUtil;
import tpot.builtins.CategoricalOutputFeature;
import tpot.builtins.ContinuousOutputFeature;

/**
 * Converter for the {@code StackingEstimator} transformer found in TPOT's
 * {@code tpot.builtins} module.
 *
 * <p>The wrapped estimator is encoded as a nested PMML model. Its outputs are exposed as
 * new features, followed by the original input features.
 */
public class StackingEstimator
extends Transformer
implements HasEstimator<Estimator> {
    public StackingEstimator(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes the wrapped estimator as a PMML model and exposes its outputs as features.
     *
     * <p>The returned list holds the stacked output features followed by all input
     * {@code features}, in their original order.
     * <ul>
     *   <li>For a classifier, the stacked outputs are the predicted class first. After it, when
     *       the classifier has a probability distribution, come one probability feature per
     *       class, in class order.</li>
     *   <li>For a regressor, the single stacked output is the predicted value.</li>
     * </ul>
     *
     * @param features the input features; used as the wrapped model's schema and appended
     *     unchanged at the end of the result
     * @param encoder the encoder that receives the wrapped model via {@code addTransformer}
     * @return the stacked output features followed by {@code features}
     * @throws IllegalArgumentException if the wrapped estimator's mining function is neither
     *     classification nor regression
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Estimator estimator = getEstimator();

        // Output field names are derived from the input width, e.g. "stack(4)".
        FieldName name = FieldName.create("stack(" + features.size() + ")");

        List<Feature> result = new ArrayList<>();

        // Typed as the common supertype: the classification branch assigns a CategoricalLabel,
        // the regression branch a ContinuousLabel.
        Label label;
        Output output;

        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                Classifier classifier = EstimatorUtil.asClassifier(estimator);
                List<?> classes = EstimatorUtil.getClasses(estimator);
                DataType dataType = TypeUtil.getDataType(classes, DataType.STRING);
                List<String> categories = EstimatorUtil.formatTargetCategories(classes);

                CategoricalLabel categoricalLabel = new CategoricalLabel(null, dataType, categories);
                label = categoricalLabel;

                output = ModelUtil.createPredictedOutput(name, OpType.CATEGORICAL, categoricalLabel.getDataType(), new Transformation[0]);

                // Column order follows TPOT's StackingEstimator.transform, which stacks
                // predict(X) before predict_proba(X). The predicted class therefore comes first.
                result.add(new CategoricalOutputFeature(encoder, output, name, categoricalLabel.getDataType(), categories));

                if (classifier.hasProbabilityDistribution()) {
                    for (String category : categories) {
                        OutputField probabilityField = ModelUtil.createProbabilityField(DataType.DOUBLE, category)
                            .setName(FieldName.create("probability(" + name.getValue() + ", " + category + ")"))
                            // Marked as not a final result of the overall model.
                            .setFinalResult(false);

                        output.addOutputFields(probabilityField);
                        result.add(new ContinuousOutputFeature(encoder, output, probabilityField));
                    }
                }
                break;
            }
            case REGRESSION: {
                ContinuousLabel continuousLabel = new ContinuousLabel(null, DataType.DOUBLE);
                label = continuousLabel;

                output = ModelUtil.createPredictedOutput(name, OpType.CONTINUOUS, continuousLabel.getDataType(), new Transformation[0]);

                result.add(new ContinuousOutputFeature(encoder, output, name, continuousLabel.getDataType()));
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }

        // Encode the wrapped estimator against the incoming features, attach the stacking
        // outputs to it, and register it with the encoder.
        Schema schema = new Schema(label, features);
        Model model = estimator.encodeModel(schema, encoder).setOutput(output);
        encoder.addTransformer(model);

        // The original input features are passed through after the stacked outputs.
        result.addAll(features);
        return result;
    }

    /**
     * Returns the wrapped estimator.
     *
     * <p>The estimator is read from this object's {@code "estimator"} entry and converted with
     * {@link EstimatorUtil#asEstimator}.
     *
     * @return the wrapped estimator
     */
    @Override
    public Estimator getEstimator() {
        ClassDict estimator = (ClassDict) get("estimator");
        return EstimatorUtil.asEstimator(estimator);
    }
}

