/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.voting_classifier;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.EstimatorUtil;

/**
 * Voting ensemble of classifiers, encoded as a PMML {@link MiningModel} whose
 * segmentation contains one segment per fitted member classifier.
 *
 * <p>State is read from the {@code estimators_}, {@code voting} and {@code weights} attributes.
 */
public class VotingClassifier
extends Classifier {
    public VotingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, as reported by the first fitted member classifier.
     *
     * <p>NOTE(review): assumes every member shares the same feature space — confirm.
     *
     * @throws IllegalStateException if there are no fitted member classifiers
     */
    @Override
    public int getNumberOfFeatures() {
        List<? extends Classifier> estimators = this.getEstimators();
        if (estimators.isEmpty()) {
            throw new IllegalStateException("Expected one or more fitted estimators, got none");
        }
        Estimator estimator = estimators.get(0);
        return estimator.getNumberOfFeatures();
    }

    /**
     * Collects the define functions of all member classifiers, de-duplicated by function name.
     *
     * <p>If several members contribute a function with the same name, the one from the later
     * member replaces the earlier value, while the position stays where the name was first seen
     * ({@link LinkedHashMap#put} keeps the original insertion order for an existing key).
     *
     * @return an insertion-ordered set of define functions
     */
    @Override
    public Set<DefineFunction> encodeDefineFunctions() {
        List<? extends Classifier> estimators = this.getEstimators();
        LinkedHashMap<String, DefineFunction> uniqueDefineFunctions = new LinkedHashMap<String, DefineFunction>();
        for (Classifier classifier : estimators) {
            Set<DefineFunction> defineFunctions = classifier.encodeDefineFunctions();
            for (DefineFunction defineFunction : defineFunctions) {
                uniqueDefineFunctions.put(defineFunction.getName(), defineFunction);
            }
        }
        return new LinkedHashSet<DefineFunction>(uniqueDefineFunctions.values());
    }

    /**
     * Encodes the ensemble as a classification {@link MiningModel}.
     *
     * <p>Each member classifier is encoded against the same {@code schema} and becomes one
     * segment. The multiple-model method is derived from the {@code voting} attribute and
     * from whether non-empty weights are present.
     *
     * @param schema the schema shared by the ensemble and all of its members
     * @return the mining model, with a probability output
     * @throws IllegalArgumentException if the {@code voting} value is not supported
     */
    @Override
    public Model encodeModel(Schema schema) {
        List<? extends Classifier> estimators = this.getEstimators();
        List<? extends Number> weights = this.getWeights();

        List<Model> models = new ArrayList<Model>(estimators.size());
        for (Classifier classifier : estimators) {
            models.add(classifier.encodeModel(schema));
        }

        // An empty weights list is treated as "unweighted". Pass null in that case, so the
        // segmentation is built consistently with the unweighted multiple-model method
        // instead of receiving a zero-length weight list.
        boolean weighted = (weights != null && !weights.isEmpty());

        String voting = this.getVoting();
        Segmentation.MultipleModelMethod multipleModelMethod = VotingClassifier.parseVoting(voting, weighted);

        Segmentation segmentation = MiningModelUtil.createSegmentation(multipleModelMethod, models, weighted ? weights : null);

        MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema))
            .setSegmentation(segmentation)
            .setOutput(ModelUtil.createProbabilityOutput(schema));
        return miningModel;
    }

    /**
     * Returns the fitted member classifiers, as stored in the {@code estimators_} attribute.
     */
    public List<? extends Classifier> getEstimators() {
        List<?> estimators = (List<?>)this.get("estimators_");
        return EstimatorUtil.asClassifierList(estimators);
    }

    /**
     * Returns the raw {@code voting} attribute. {@link #parseVoting} accepts {@code "hard"} and {@code "soft"}.
     */
    public String getVoting() {
        return (String)this.get("voting");
    }

    /**
     * Returns the per-estimator weights, or {@code null} if the {@code weights} attribute is unset.
     *
     * <p>A {@link List} value is returned as is. Any other value is converted through
     * {@link ClassDictUtil#getArray} (presumably an array-like object — TODO confirm).
     */
    @SuppressWarnings("unchecked")
    public List<? extends Number> getWeights() {
        Object weights = this.get("weights");
        if (weights == null || weights instanceof List) {
            return (List<? extends Number>)weights;
        }
        return ClassDictUtil.getArray(this, "weights");
    }

    /**
     * Maps a voting mode to the corresponding PMML multiple-model method.
     *
     * <ul>
     *   <li>{@code "hard"}: {@code MAJORITY_VOTE}, or {@code WEIGHTED_MAJORITY_VOTE} if weighted</li>
     *   <li>{@code "soft"}: {@code AVERAGE}, or {@code WEIGHTED_AVERAGE} if weighted</li>
     * </ul>
     *
     * @param voting the voting mode
     * @param weighted whether non-empty weights are present
     * @throws IllegalArgumentException if {@code voting} is not one of the supported modes
     */
    private static Segmentation.MultipleModelMethod parseVoting(String voting, boolean weighted) {
        switch (voting) {
            case "hard": {
                return weighted ? Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE : Segmentation.MultipleModelMethod.MAJORITY_VOTE;
            }
            case "soft": {
                return weighted ? Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE : Segmentation.MultipleModelMethod.AVERAGE;
            }
        }
        throw new IllegalArgumentException(voting);
    }
}

