/*
 * Decompiled with CFR 0.152.
 */
package tpot.builtins;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.Transformer;

public class StackingEstimator
extends Transformer
implements HasEstimator<Estimator> {
    public StackingEstimator(String module, String name) {
        super(module, name);
    }

    public int getNumberOfFeatures() {
        Estimator estimator = this.getEstimator();
        return estimator.getNumberOfFeatures();
    }

    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        ContinuousLabel scalarLabel;
        Estimator estimator = this.getEstimator();
        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                List categories = EstimatorUtil.getClasses((Estimator)estimator);
                DataType dataType = TypeUtil.getDataType((Collection)categories, (DataType)DataType.STRING);
                scalarLabel = new CategoricalLabel(dataType, categories);
                break;
            }
            case REGRESSION: {
                scalarLabel = new ContinuousLabel(DataType.DOUBLE);
                break;
            }
            default: {
                throw new IllegalArgumentException();
            }
        }
        Schema schema = new Schema((PMMLEncoder)encoder, (Label)scalarLabel, features);
        Model model = estimator.encode(schema);
        encoder.addTransformer(model);
        String name = this.createFieldName("stack", features);
        ArrayList<Feature> result = new ArrayList<Feature>();
        switch (miningFunction) {
            case CLASSIFICATION: 
            case REGRESSION: {
                Feature feature = encoder.exportPrediction(model, name, (ScalarLabel)scalarLabel);
                result.add(feature);
                break;
            }
            default: {
                throw new IllegalArgumentException();
            }
        }
        switch (miningFunction) {
            case CLASSIFICATION: {
                Classifier classifier = (Classifier)estimator;
                if (!classifier.hasProbabilityDistribution()) break;
                List categories = EstimatorUtil.getClasses((Estimator)estimator);
                for (Object category : categories) {
                    Feature feature = encoder.exportProbability(model, FieldNameUtil.create((String)"probability", (Object[])new Object[]{name, category}), category);
                    result.add(feature);
                }
                break;
            }
            case REGRESSION: {
                break;
            }
            default: {
                throw new IllegalArgumentException();
            }
        }
        result.addAll(features);
        return result;
    }

    public Estimator getEstimator() {
        return (Estimator)this.get("estimator", Estimator.class);
    }
}

