/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.stacking;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;
import sklearn.HasEstimatorEnsemble;
import sklearn.StepUtil;
import sklearn.ensemble.stacking.StackingUtil;

/**
 * Converter for a fitted stacking ensemble of classifiers.
 *
 * <p>The base classifiers are read from the {@code estimators_} attribute, the
 * meta-level classifier from {@code final_estimator_}, the per-base-classifier
 * stacking methods from {@code stack_method_}, and the pass-through flag from
 * {@code passthrough}. (These names presumably mirror scikit-learn's fitted
 * {@code StackingClassifier} attributes; confirm against the Python side.)
 */
public class StackingClassifier extends Classifier implements HasEstimatorEnsemble<Classifier> {

    /**
     * Creates a new instance.
     *
     * <p>Both arguments are forwarded unchanged to the {@link Classifier} constructor.
     *
     * @param module module identifier handed to the superclass
     * @param name class name identifier handed to the superclass
     */
    public StackingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the input feature count.
     *
     * <p>The count is delegated to {@code StepUtil.getNumberOfFeatures} over the
     * base classifiers.
     */
    @Override
    public int getNumberOfFeatures() {
        return StepUtil.getNumberOfFeatures(getEstimators());
    }

    /**
     * Encodes this ensemble as a PMML {@link MiningModel}.
     *
     * <p>Each base classifier contributes probability features, which serve as input
     * for the final classifier:
     * <ul>
     *   <li>If the label has exactly two values, one feature is exported per base
     *       classifier: the probability of the value at index 1.</li>
     *   <li>Otherwise, one feature is exported per label value.</li>
     * </ul>
     * Assembling the model is delegated to {@code StackingUtil.encodeStacking}.
     *
     * <p>The predict function handed to {@code encodeStacking} throws
     * {@link IllegalArgumentException} for any stack method other than
     * {@code "predict_proba"}. Whether and when that surfaces depends on how
     * {@code encodeStacking} invokes the function.
     *
     * @param schema the schema; its label is cast to {@link CategoricalLabel}
     * @return the encoded mining model
     */
    public MiningModel encodeModel(Schema schema) {
        // Attribute lookups are kept in this fixed order.
        List<? extends Classifier> estimators = getEstimators();
        Classifier finalEstimator = getFinalEstimator();
        Boolean passthrough = getPassthrough();
        List<String> stackMethods = getStackMethod();

        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        // Binary label: export only the probability of the second value. This is
        // presumably because the two probabilities are complementary.
        final List<Object> values;
        if (categoricalLabel.size() == 2) {
            values = Collections.singletonList(categoricalLabel.getValue(1));
        } else {
            values = categoricalLabel.getValues();
        }

        StackingUtil.PredictFunction predictFunction = new StackingUtil.PredictFunction() {

            @Override
            public List<Feature> apply(int index, Model model, String method, SkLearnEncoder encoder) {
                // Only probability-based stacking is supported.
                if (!method.equals("predict_proba")) {
                    throw new IllegalArgumentException(method);
                }

                List<Feature> features = new ArrayList<>(values.size());
                for (Object value : values) {
                    features.add(encoder.exportProbability(model, FieldNameUtil.create(method, new Object[]{index, value}), value));
                }
                return features;
            }
        };

        return StackingUtil.encodeStacking(estimators, stackMethods, predictFunction, finalEstimator, passthrough, schema);
    }

    /**
     * Returns the fitted base classifiers.
     *
     * <p>They are stored in the {@code estimators_} attribute.
     */
    @Override
    public List<? extends Classifier> getEstimators() {
        return getList("estimators_", Classifier.class);
    }

    /**
     * Returns the fitted meta-level classifier.
     *
     * <p>It is stored in the {@code final_estimator_} attribute.
     */
    public Classifier getFinalEstimator() {
        return (Classifier) get("final_estimator_", Classifier.class);
    }

    /**
     * Returns the {@code passthrough} flag.
     *
     * <p>The flag is forwarded as-is to {@code StackingUtil.encodeStacking}.
     */
    public Boolean getPassthrough() {
        return getBoolean("passthrough");
    }

    /**
     * Returns the stacking method name for each base classifier.
     *
     * <p>The names are stored in the {@code stack_method_} attribute.
     */
    public List<String> getStackMethod() {
        return getList("stack_method_", String.class);
    }
}

