/*
 * Decompiled with CFR 0.152.
 */
package sklearn.calibration;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Calibrator;
import sklearn.Classifier;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.SkLearnClassifier;
import sklearn.ensemble.gradient_boosting.GradientBoostingClassifier;
import sklearn.linear_model.LinearClassifier;

/**
 * Converter for a calibrated classifier (package {@code sklearn.calibration}).
 *
 * <p>The wrapped estimator is encoded first. The values to be calibrated are then exposed as
 * non-final intermediate output fields:
 * <ul>
 *   <li>for {@link GradientBoostingClassifier} and {@link LinearClassifier} estimators, the raw
 *       decision function value(s);</li>
 *   <li>for all other estimators, the class probability output fields.</li>
 * </ul>
 * Each of these values is passed through its {@link Calibrator}. The calibrated values are then
 * combined by a final classification {@link RegressionModel}, and all stages are returned as a
 * model chain.
 */
public class CalibratedClassifier
extends SkLearnClassifier
implements HasEstimator<Classifier> {
    private static final String METHOD_ISOTONIC = "isotonic";
    private static final String METHOD_SIGMOID = "sigmoid";

    public CalibratedClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this calibrated classifier as a PMML model chain.
     *
     * @param schema the schema; its label must be a {@link CategoricalLabel} and its encoder a
     *     {@link SkLearnEncoder}
     * @return a model chain whose last member is the calibrated classification model
     * @throws IllegalArgumentException if the encoded estimator does not have the expected
     *     structure, or if the number of calibrators does not match the number of decision
     *     function features or classes
     */
    @Override
    public Model encodeModel(Schema schema) {
        List<Calibrator> calibrators = this.getCalibrators();
        // The "classes" and "method" attributes are not used below. They are still read here, in
        // the same order as before, presumably so that a missing or unsupported value fails early.
        this.getClasses();
        Classifier estimator = this.getEstimator();
        this.getMethod();

        SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        Model model = estimator.encode(schema);

        // Members of the model chain, in evaluation order; the calibrator model is appended last
        List<Model> models = new ArrayList<>();
        // Uncalibrated input values, one per calibrator
        List<ContinuousFeature> decisionFunctionFeatures;

        if (estimator instanceof GradientBoostingClassifier || estimator instanceof LinearClassifier) {
            if (model instanceof MiningModel) {
                decisionFunctionFeatures = this.encodeSegmentDecisionFunctions((MiningModel)model, estimator, encoder, models);
            } else if (model instanceof RegressionModel) {
                decisionFunctionFeatures = this.encodeRegressionTableDecisionFunctions((RegressionModel)model, categoricalLabel, encoder, models);
            } else {
                throw new IllegalArgumentException("Expected a MiningModel or a RegressionModel, got " + model.getClass().getName());
            }
        } else {
            decisionFunctionFeatures = this.encodeProbabilityFeatures(model, categoricalLabel, encoder);
            models.add(model);
        }

        SchemaUtil.checkSize(calibrators.size(), decisionFunctionFeatures);

        RegressionModel calibratorModel;
        if (calibrators.size() == 1) {
            // A single calibrator is only valid for a binary label
            SchemaUtil.checkSize(2, (DiscreteLabel)categoricalLabel);

            Feature calibratedFeature = CalibratedClassifier.calibrate(calibrators.get(0), models.get(0), decisionFunctionFeatures.get(0), encoder);

            calibratorModel = RegressionModelUtil.createBinaryLogisticClassification(Collections.singletonList(calibratedFeature), Collections.singletonList(1.0), null, RegressionModel.NormalizationMethod.NONE, false, schema);
        } else if (calibrators.size() >= 3) {
            // One calibrator per class
            SchemaUtil.checkSize(calibrators.size(), (DiscreteLabel)categoricalLabel);

            List<RegressionTable> regressionTables = new ArrayList<>(calibrators.size());
            for (int i = 0; i < calibrators.size(); ++i) {
                // A single upstream model (the probability case) feeds every calibrator;
                // otherwise there is one decision function model per class
                Model featureModel = (models.size() == 1) ? models.get(0) : models.get(i);

                Feature calibratedFeature = CalibratedClassifier.calibrate(calibrators.get(i), featureModel, decisionFunctionFeatures.get(i), encoder);

                RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(calibratedFeature), Collections.singletonList(1.0), null)
                    .setTargetCategory(categoricalLabel.getValue(i));

                regressionTables.add(regressionTable);
            }

            calibratorModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), regressionTables)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SIMPLEMAX);
        } else {
            throw new IllegalArgumentException("Expected 1 or at least 3 calibrators, got " + calibrators.size());
        }

        this.encodePredictProbaOutput(calibratorModel, DataType.DOUBLE, (DiscreteLabel)categoricalLabel);

        models.add(calibratorModel);

        return MiningModelUtil.createModelChain(models, Segmentation.MissingPredictionTreatment.RETURN_MISSING);
    }

    /**
     * Extracts decision functions from a segmented encoding of a gradient boosting or linear
     * classifier.
     *
     * <p>Every segment except the last must hold a regression model with a single predicted value
     * output field. For gradient boosting, a second output field is dropped. That field is renamed
     * to a decision function field, and the segment's model is appended to {@code models}. The
     * last segment must hold a classification {@link RegressionModel}. It is validated but not
     * kept.
     *
     * @return one decision function feature per regressor segment
     */
    private List<ContinuousFeature> encodeSegmentDecisionFunctions(MiningModel miningModel, Classifier estimator, SkLearnEncoder encoder, List<Model> models) {
        Segmentation segmentation = miningModel.requireSegmentation();
        List<Segment> segments = segmentation.requireSegments();
        if (segments.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one segment");
        }

        List<ContinuousFeature> result = new ArrayList<>();

        List<Segment> regressorSegments = segments.subList(0, segments.size() - 1);
        for (Segment regressorSegment : regressorSegments) {
            Model decisionFunctionModel = regressorSegment.requireModel();

            MiningFunction miningFunction = decisionFunctionModel.requireMiningFunction();
            if (miningFunction != MiningFunction.REGRESSION) {
                throw new IllegalArgumentException("Expected a regression segment model, got " + miningFunction);
            }

            Output output = decisionFunctionModel.getOutput();
            if (output == null || !output.hasOutputFields()) {
                throw new IllegalArgumentException("Expected a regression segment model with output fields");
            }

            List<OutputField> outputFields = output.getOutputFields();
            if (estimator instanceof LinearClassifier) {
                // Expose the un-normalized value of the linear model
                RegressionModel regressionModel = (RegressionModel)decisionFunctionModel;
                regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.NONE);
            } else if (estimator instanceof GradientBoostingClassifier) {
                // Keep only the first output field (mutates the segment model's output)
                if (outputFields.size() == 2) {
                    outputFields.remove(1);
                } else if (outputFields.size() != 1) {
                    throw new IllegalArgumentException("Expected 1 or 2 output fields, got " + outputFields.size());
                }
            } else {
                throw new IllegalArgumentException();
            }

            OutputField outputField = Iterables.getOnlyElement(outputFields);
            if (outputField.getResultFeature() != ResultFeature.PREDICTED_VALUE) {
                throw new IllegalArgumentException("Expected a predicted value output field, got " + outputField.getResultFeature());
            }

            outputField.setName(this.getDecisionFunctionField(outputField.requireName()));

            models.add(decisionFunctionModel);
            result.add(new ContinuousFeature(encoder, outputField));
        }

        Segment classifierSegment = segments.get(segments.size() - 1);
        RegressionModel normalizerModel = (RegressionModel)classifierSegment.requireModel();
        MiningFunction normalizerFunction = normalizerModel.requireMiningFunction();
        if (normalizerFunction != MiningFunction.CLASSIFICATION) {
            throw new IllegalArgumentException("Expected a classification segment model, got " + normalizerFunction);
        }

        return result;
    }

    /**
     * Extracts decision functions from a (non-segmented) regression model encoding.
     *
     * <p>Each regression table becomes a standalone regression-mode {@link RegressionModel}.
     * For a binary label only the first table is used. Each such model has no normalization and
     * a single non-final predicted output field named after the table's target category. It is
     * appended to {@code models}. The tables lose their target category in the process.
     *
     * @return one decision function feature per used regression table
     */
    private List<ContinuousFeature> encodeRegressionTableDecisionFunctions(RegressionModel regressionModel, CategoricalLabel categoricalLabel, SkLearnEncoder encoder, List<Model> models) {
        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (categoricalLabel.size() == 2) {
            regressionTables = regressionTables.subList(0, 1);
        }

        List<ContinuousFeature> result = new ArrayList<>(regressionTables.size());

        for (RegressionTable regressionTable : regressionTables) {
            OutputField outputField = ModelUtil.createPredictedField(this.getDecisionFunctionField(regressionTable.requireTargetCategory()), OpType.CONTINUOUS, DataType.DOUBLE)
                .setFinalResult(false);

            Output output = new Output()
                .addOutputFields(outputField);

            regressionTable.setTargetCategory(null);

            Model decisionFunctionModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null), null)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.NONE)
                .addRegressionTables(regressionTable)
                .setOutput(output);

            models.add(decisionFunctionModel);
            result.add(new ContinuousFeature(encoder, outputField));
        }

        return result;
    }

    /**
     * Collects the probability output fields of the model's final output as decision function
     * features, renaming each one to a decision function field.
     *
     * <p>For a binary label exactly two probability fields are required, and only the second one
     * is returned.
     *
     * @return the probability features to be calibrated
     * @throws IllegalArgumentException if the model has no final output
     */
    private List<ContinuousFeature> encodeProbabilityFeatures(Model model, CategoricalLabel categoricalLabel, SkLearnEncoder encoder) {
        Output output = EstimatorUtil.getFinalOutput(model);
        if (output == null) {
            throw new IllegalArgumentException("Expected a model with a final output");
        }

        List<ContinuousFeature> result = new ArrayList<>();

        for (OutputField outputField : output.getOutputFields()) {
            if (outputField.getResultFeature() == ResultFeature.PROBABILITY) {
                outputField.setName(this.getDecisionFunctionField(outputField.getValue()));

                result.add(new ContinuousFeature(encoder, outputField));
            }
        }

        if (categoricalLabel.size() == 2) {
            SchemaUtil.checkSize(2, result);

            result.remove(0);
        }

        return result;
    }

    /**
     * Returns the class labels, read from the {@code classes} attribute.
     */
    @Override
    public List<?> getClasses() {
        return this.getClasses("classes");
    }

    /**
     * Returns the fitted calibrators, read from the {@code calibrators} attribute.
     */
    public List<Calibrator> getCalibrators() {
        return this.getList("calibrators", Calibrator.class);
    }

    /**
     * Returns the wrapped classifier, read from the {@code estimator} attribute.
     */
    @Override
    public Classifier getEstimator() {
        return (Classifier)this.get("estimator", Classifier.class);
    }

    /**
     * Returns the calibration method, read from the {@code method} attribute; one of
     * {@code "isotonic"} or {@code "sigmoid"}.
     */
    public String getMethod() {
        return (String)this.getEnum("method", this::getString, Arrays.asList(METHOD_ISOTONIC, METHOD_SIGMOID));
    }

    /**
     * Builds the name of a decision function field for the given value.
     *
     * <p>A string value is first passed through {@code extractArguments("decisionFunction", ...)}.
     * When a PMML segment id is set, it is prepended to the name arguments.
     */
    private String getDecisionFunctionField(Object value) {
        Object pmmlSegmentId = this.getPMMLSegmentId();

        if (value instanceof String) {
            value = CalibratedClassifier.extractArguments("decisionFunction", (String)value);
        }

        List<Object> args = (pmmlSegmentId != null) ? Arrays.asList(pmmlSegmentId, value) : Arrays.asList(value);

        return FieldNameUtil.create("decisionFunction", args);
    }

    /**
     * Applies a calibrator to a feature and attaches the result to the given model.
     *
     * <p>The feature is exported from {@code model}, and the calibrator is encoded against it.
     * The derived field created by the calibrator is then removed from the encoder and re-added
     * to {@code model} as a non-final {@code TRANSFORMED_VALUE} output field with the same name,
     * type and expression.
     *
     * @return a continuous feature backed by the re-created derived field
     */
    private static Feature calibrate(Calibrator calibrator, Model model, Feature feature, SkLearnEncoder encoder) {
        encoder.export(model, feature.getName());

        Feature calibratedFeature = Iterables.getOnlyElement(calibrator.encodeFeatures(Collections.singletonList(feature), encoder));

        DerivedField derivedField = encoder.removeDerivedField(calibratedFeature.getName());

        OutputField outputField = new OutputField(derivedField.requireName(), derivedField.requireOpType(), derivedField.requireDataType())
            .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
            .setExpression(derivedField.requireExpression())
            .setFinalResult(false);

        DerivedField modelDerivedField = encoder.createDerivedField(model, outputField, true);

        return new ContinuousFeature(encoder, modelDerivedField);
    }
}

