/*
 * Decompiled with CFR 0.152.
 */
package sklearn.multiclass;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Classifier;
import sklearn.ClassifierUtil;
import sklearn.EstimatorUtil;
import sklearn.HasEstimatorEnsemble;

public class OneVsRestClassifier
extends Classifier
implements HasEstimatorEnsemble<Classifier> {
    public OneVsRestClassifier(String module, String name) {
        super(module, name);
    }

    @Override
    public int getNumberOfFeatures() {
        return EstimatorUtil.getNumberOfFeatures(this);
    }

    @Override
    public Model encodeModel(Schema schema) {
        List<? extends Classifier> estimators = this.getEstimators();
        Boolean multilabel = this.getMultilabel();
        if (multilabel != null && multilabel.booleanValue()) {
            throw new IllegalArgumentException();
        }
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        if (estimators.size() == 1) {
            ClassifierUtil.checkSize(2, categoricalLabel);
            Classifier estimator = estimators.get(0);
            if (!estimator.hasProbabilityDistribution()) {
                throw new IllegalArgumentException();
            }
            return estimator.encodeModel(schema);
        }
        if (estimators.size() >= 2) {
            ClassifierUtil.checkSize(estimators.size(), categoricalLabel);
            ArrayList<Model> models = new ArrayList<Model>();
            for (int i = 0; i < estimators.size(); ++i) {
                Classifier estimator = estimators.get(i);
                if (!estimator.hasProbabilityDistribution()) {
                    throw new IllegalArgumentException();
                }
                Output output = new Output().addOutputFields(new OutputField[]{ModelUtil.createProbabilityField((FieldName)FieldName.create((String)("decisionFunction(" + categoricalLabel.getValue(i) + ")")), (DataType)DataType.DOUBLE, (String)categoricalLabel.getValue(i))});
                Schema segmentSchema = new Schema((Label)new CategoricalLabel(null, DataType.STRING, Arrays.asList("(other)", categoricalLabel.getValue(i))), schema.getFeatures());
                Model model = estimator.encodeModel(segmentSchema).setOutput(output);
                models.add(model);
            }
            return MiningModelUtil.createClassification(models, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SIMPLEMAX, (boolean)true, (Schema)schema);
        }
        throw new IllegalArgumentException();
    }

    @Override
    public List<? extends Classifier> getEstimators() {
        return this.getList("estimators_", Classifier.class);
    }

    public Boolean getMultilabel() {
        return (Boolean)this.get("multilabel_");
    }
}

