/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.voting_classifier;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.FeatureMapper;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.EstimatorUtil;

/**
 * Converts a scikit-learn style {@code VotingClassifier} into a PMML {@link MiningModel}.
 *
 * <p>The fitted member estimators are read from the {@code estimators_} attribute, the voting
 * strategy from {@code voting} ({@code "hard"} or {@code "soft"}) and the optional member weights
 * from {@code weights}. Each member is encoded as its own model and the models are combined in a
 * {@link Segmentation} using a (weighted) majority vote or a (weighted) average.
 *
 * <p>Instances hold mutable per-conversion state: {@link #createSchema(FeatureMapper)} fills a
 * schema cache that {@link #encodeModel(Schema)} reads, so {@code createSchema} must run first.
 * The cache is a plain {@link HashMap}, with no synchronization.
 */
public class VotingClassifier
extends Classifier {
    /**
     * Schemas keyed by {@link #createSchemaKey(Estimator)}. Populated by
     * {@link #createSchema(FeatureMapper)} and read by {@link #encodeModel(Schema)}.
     */
    private final Map<List<?>, Schema> schemas = new HashMap<>();

    public VotingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Creates the ensemble's schema and derives one schema per distinct member requirement.
     *
     * <p>The root schema from {@code super.createSchema} is cached under this ensemble's own key.
     * Each member whose key (operational type, data type) is not yet cached gets a schema produced
     * by {@link FeatureMapper#cast} from the root schema.
     *
     * @param featureMapper mapper used to build the root schema and the cast member schemas
     * @return the root schema of the ensemble
     */
    @Override
    public Schema createSchema(FeatureMapper featureMapper) {
        Schema schema = super.createSchema(featureMapper);
        // Discard schemas derived during an earlier call. Otherwise members would keep schemas
        // cast from a previous (possibly different) root schema.
        this.schemas.clear();
        this.schemas.put(createSchemaKey(this), schema);
        for (Classifier classifier : this.getEstimators()) {
            List<?> schemaKey = createSchemaKey(classifier);
            if (this.schemas.get(schemaKey) != null) {
                continue;
            }
            Schema estimatorSchema = featureMapper.cast((OpType) schemaKey.get(0), (DataType) schemaKey.get(1), schema);
            this.schemas.put(schemaKey, estimatorSchema);
        }
        return schema;
    }

    /**
     * Returns the number of input features as reported by the first member estimator.
     *
     * @throws IllegalArgumentException if {@code estimators_} holds no estimators
     */
    @Override
    public int getNumberOfFeatures() {
        List<? extends Classifier> estimators = this.getEstimators();
        if (estimators.isEmpty()) {
            throw new IllegalArgumentException("Attribute 'estimators_' contains no estimators");
        }
        // Presumably all members share the same input; createSchema derives every member schema
        // from one root schema.
        Estimator estimator = estimators.get(0);
        return estimator.getNumberOfFeatures();
    }

    /**
     * The ensemble itself does not require continuous input. Members that do are given their own
     * cast schema in {@link #createSchema(FeatureMapper)}.
     */
    @Override
    public boolean requiresContinuousInput() {
        return false;
    }

    /**
     * Collects the define-functions of all member estimators, de-duplicated by function name.
     *
     * <p>When two members define a function with the same name, the later definition wins but
     * keeps the position of the first occurrence (re-inserting a {@link LinkedHashMap} key does
     * not change its order).
     *
     * @return an insertion-ordered set of unique define-functions
     */
    @Override
    public Set<DefineFunction> encodeDefineFunctions() {
        Map<String, DefineFunction> uniqueDefineFunctions = new LinkedHashMap<>();
        for (Classifier classifier : this.getEstimators()) {
            for (DefineFunction defineFunction : classifier.encodeDefineFunctions()) {
                uniqueDefineFunctions.put(defineFunction.getName(), defineFunction);
            }
        }
        return new LinkedHashSet<>(uniqueDefineFunctions.values());
    }

    /**
     * Encodes every member with its cached schema and combines the results in a classification
     * {@link MiningModel}.
     *
     * @param schema the ensemble's root schema (used for the mining schema and probability output)
     * @return the ensemble mining model
     * @throws IllegalArgumentException if a member schema is missing from the cache, or if the
     *         {@code voting} attribute is not {@code "hard"} or {@code "soft"}
     */
    @Override
    public Model encodeModel(Schema schema) {
        List<? extends Classifier> estimators = this.getEstimators();
        List<? extends Number> weights = this.getWeights();
        List<Model> models = new ArrayList<>(estimators.size());
        for (Classifier classifier : estimators) {
            List<?> schemaKey = createSchemaKey(classifier);
            Schema estimatorSchema = this.schemas.get(schemaKey);
            if (estimatorSchema == null) {
                throw new IllegalArgumentException("No schema for estimator key " + schemaKey
                    + "; createSchema(FeatureMapper) must be invoked before encodeModel(Schema)");
            }
            models.add(classifier.encodeModel(estimatorSchema));
        }
        boolean weighted = weights != null && !weights.isEmpty();
        Segmentation.MultipleModelMethod multipleModelMethod = parseVoting(this.getVoting(), weighted);
        MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema))
            .setSegmentation(MiningModelUtil.createSegmentation(multipleModelMethod, models, weights))
            .setOutput(ModelUtil.createProbabilityOutput(schema));
        return miningModel;
    }

    /** Returns the fitted member estimators stored in the {@code estimators_} attribute. */
    public List<? extends Classifier> getEstimators() {
        List<?> estimators = (List<?>) this.get("estimators_");
        return EstimatorUtil.asClassifierList(estimators);
    }

    /** Returns the {@code voting} attribute (expected to be {@code "hard"} or {@code "soft"}). */
    public String getVoting() {
        return (String) this.get("voting");
    }

    /**
     * Returns the member weights, or {@code null} when the {@code weights} attribute is unset.
     *
     * <p>A {@link List} value is returned as-is. Any other value is converted through
     * {@link ClassDictUtil#getArray}. Element types are not verified here.
     */
    @SuppressWarnings("unchecked") // element types of the pickled value cannot be checked at runtime
    public List<? extends Number> getWeights() {
        Object weights = this.get("weights");
        if (weights == null || weights instanceof List) {
            return (List<? extends Number>) weights;
        }
        return (List<? extends Number>) ClassDictUtil.getArray(this, "weights");
    }

    /**
     * Maps a voting strategy name to a PMML multiple-model method.
     *
     * @param voting   {@code "hard"} (majority vote) or {@code "soft"} (average)
     * @param weighted whether non-empty member weights are present
     * @throws IllegalArgumentException if {@code voting} is {@code null} or unrecognized
     */
    private static Segmentation.MultipleModelMethod parseVoting(String voting, boolean weighted) {
        if (voting == null) {
            throw new IllegalArgumentException("Missing voting strategy (expected \"hard\" or \"soft\")");
        }
        switch (voting) {
            case "hard":
                return weighted ? Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE : Segmentation.MultipleModelMethod.MAJORITY_VOTE;
            case "soft":
                return weighted ? Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE : Segmentation.MultipleModelMethod.AVERAGE;
            default:
                throw new IllegalArgumentException("Unsupported voting strategy: " + voting);
        }
    }

    /**
     * Builds the schema-cache key {@code [opType, dataType]} for an estimator.
     *
     * <p>{@code opType} is {@link OpType#CONTINUOUS} when the estimator requires continuous input,
     * otherwise {@code null}. The {@link Arrays#asList} result has value-based
     * {@code equals}/{@code hashCode}, so it works as a {@link HashMap} key.
     */
    private static List<?> createSchemaKey(Estimator estimator) {
        OpType opType = estimator.requiresContinuousInput() ? OpType.CONTINUOUS : null;
        return Arrays.<Object>asList(opType, estimator.getDataType());
    }
}

