/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.ensemble;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.OrdinalLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;

public class OrdinalClassifier
extends Classifier {
    public OrdinalClassifier(String module, String name) {
        super(module, name);
    }

    @Override
    public Model encodeModel(Schema schema) {
        List<? extends Classifier> estimators = this.getEstimators();
        SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
        OrdinalLabel ordinalLabel = (OrdinalLabel)schema.getLabel();
        List features = schema.getFeatures();
        SchemaUtil.checkSize((int)(estimators.size() + 1), (DiscreteLabel)ordinalLabel);
        ArrayList<Object> models = new ArrayList<Object>();
        ArrayList<Feature> probabilityFeatures = new ArrayList<Feature>();
        for (int i = 0; i < estimators.size(); ++i) {
            Classifier estimator = estimators.get(i);
            if (!estimator.hasProbabilityDistribution()) {
                throw new IllegalArgumentException();
            }
            Object category = ordinalLabel.getValue(i);
            CategoricalLabel segmentLabel = new CategoricalLabel(DataType.DOUBLE, Arrays.asList("<=" + ValueUtil.asString((Object)category), ">" + ValueUtil.asString((Object)category)));
            Schema segmentSchema = schema.toRelabeledSchema((Label)segmentLabel);
            Model model = estimator.encode(segmentSchema);
            String name = FieldNameUtil.create((String)"probability", (Object[])new Object[]{segmentLabel.getValue(1)});
            List<Feature> segmentFeatures = encoder.export(model, name);
            if (segmentFeatures.size() != 1) {
                throw new IllegalArgumentException();
            }
            models.add(model);
            probabilityFeatures.addAll(segmentFeatures);
        }
        SchemaUtil.checkSize((int)estimators.size(), probabilityFeatures);
        ArrayList<RegressionTable> regressionTables = new ArrayList<RegressionTable>();
        for (int i = 0; i < estimators.size(); ++i) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(probabilityFeatures.get(i)), Collections.singletonList(-1.0), (Number)1.0).setTargetCategory(ordinalLabel.getValue(i));
            regressionTables.add(regressionTable);
        }
        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), (Number)1.0).setTargetCategory(ordinalLabel.getValue(estimators.size()));
        regressionTables.add(regressionTable);
        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)ordinalLabel), regressionTables).setNormalizationMethod(RegressionModel.NormalizationMethod.NONE);
        this.encodePredictProbaOutput((Model)regressionModel, DataType.DOUBLE, (DiscreteLabel)ordinalLabel);
        models.add(regressionModel);
        return MiningModelUtil.createModelChain(models, (Segmentation.MissingPredictionTreatment)Segmentation.MissingPredictionTreatment.RETURN_MISSING);
    }

    @Override
    protected DiscreteLabel encodeLabel(String name, List<?> categories, SkLearnEncoder encoder) {
        return this.encodeLabel(name, OpType.ORDINAL, categories, encoder);
    }

    public Classifier getEstimator() {
        return (Classifier)this.get("estimator", Classifier.class);
    }

    public List<? extends Classifier> getEstimators() {
        return this.getList("estimators_", Classifier.class);
    }
}

