/*
 * Decompiled with CFR 0.152.
 */
package sklearn.multiclass;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Classifier;
import sklearn.HasEstimatorEnsemble;
import sklearn.StepUtil;

public class OneVsRestClassifier
extends Classifier
implements HasEstimatorEnsemble<Classifier> {
    public OneVsRestClassifier(String module, String name) {
        super(module, name);
    }

    @Override
    public int getNumberOfFeatures() {
        int numberOfFeatures = super.getNumberOfFeatures();
        if (numberOfFeatures != -1) {
            return numberOfFeatures;
        }
        List<? extends Classifier> estimators = this.getEstimators();
        return StepUtil.getNumberOfFeatures(estimators);
    }

    @Override
    public Model encodeModel(Schema schema) {
        List<? extends Classifier> estimators = this.getEstimators();
        Boolean multilabel = this.getMultilabel();
        if (multilabel.booleanValue()) {
            throw new IllegalArgumentException();
        }
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        if (estimators.size() == 1) {
            SchemaUtil.checkSize((int)2, (CategoricalLabel)categoricalLabel);
            Classifier estimator = estimators.get(0);
            if (!estimator.hasProbabilityDistribution()) {
                throw new IllegalArgumentException();
            }
            return estimator.encode(schema);
        }
        if (estimators.size() >= 2) {
            SchemaUtil.checkSize((int)estimators.size(), (CategoricalLabel)categoricalLabel);
            ArrayList<Model> models = new ArrayList<Model>();
            for (int i = 0; i < estimators.size(); ++i) {
                Classifier estimator = estimators.get(i);
                if (!estimator.hasProbabilityDistribution()) {
                    throw new IllegalArgumentException();
                }
                Output output = new Output().addOutputFields(new OutputField[]{ModelUtil.createProbabilityField((String)FieldNameUtil.create((String)"decisionFunction", (Object[])new Object[]{categoricalLabel.getValue(i)}), (DataType)DataType.DOUBLE, (Object)categoricalLabel.getValue(i))});
                CategoricalLabel segmentCategoricalLabel = new CategoricalLabel(DataType.STRING, Arrays.asList("(other)", ValueUtil.asString((Object)categoricalLabel.getValue(i))));
                Schema segmentSchema = schema.toRelabeledSchema((Label)segmentCategoricalLabel);
                Model model = estimator.encode(segmentSchema).setOutput(output);
                models.add(model);
            }
            return MiningModelUtil.createClassification(models, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SIMPLEMAX, (boolean)true, (Schema)schema);
        }
        throw new IllegalArgumentException();
    }

    @Override
    public List<? extends Classifier> getEstimators() {
        return this.getList("estimators_", Classifier.class);
    }

    public Boolean getMultilabel() {
        return this.getOptionalBoolean("multilabel_", Boolean.FALSE);
    }
}

