/*
 * Decompiled with CFR 0.152.
 */
package sklearn;

import com.google.common.collect.Iterables;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
import sklearn.HasApplyField;
import sklearn.HasClasses;
import sklearn.HasDecisionFunctionField;
import sklearn.HasMultiApplyField;
import sklearn.HasPredictField;

/**
 * Static helpers for Scikit-Learn estimator wrappers: aggregating the mining function of a
 * group of estimators, looking up classifier class labels, and exporting the result fields
 * of a converted PMML model for a given Python prediction function name.
 *
 * <p>This class is not instantiable.
 */
public class EstimatorUtil {

    private EstimatorUtil() {
    }

    /**
     * Determines the mining function shared by a group of estimators.
     *
     * @param estimators the estimators to inspect
     * @return the single mining function reported by every estimator; {@link MiningFunction#MIXED}
     *     when the estimators report more than one distinct value, or when {@code estimators} is empty
     */
    public static MiningFunction getMiningFunction(List<? extends Estimator> estimators) {
        Set<MiningFunction> miningFunctions = estimators.stream()
            .map(estimator -> estimator.getMiningFunction())
            .collect(Collectors.toSet());

        // Exactly one distinct value means all estimators agree
        if (miningFunctions.size() == 1) {
            return Iterables.getOnlyElement(miningFunctions);
        }

        return MiningFunction.MIXED;
    }

    /**
     * Returns the class labels of a classifier.
     *
     * @param estimator the estimator; must implement {@link HasClasses}
     * @return the value of {@link HasClasses#getClasses()}
     * @throws IllegalArgumentException if the estimator does not implement {@link HasClasses}
     */
    public static List<?> getClasses(Estimator estimator) {
        if (estimator instanceof HasClasses) {
            HasClasses hasClasses = (HasClasses)estimator;
            return hasClasses.getClasses();
        }
        throw new IllegalArgumentException("The estimator object (" + ClassDictUtil.formatClass(estimator) + ") is not a classifier");
    }

    /**
     * Exports the model fields that correspond to a Python prediction function.
     *
     * <p>Supported values of {@code predictFunc}:
     * <ul>
     *   <li>{@code "apply"}: the fields of {@link HasMultiApplyField}, else the field of {@link HasApplyField};</li>
     *   <li>{@code "decision_function"}: the field of {@link HasDecisionFunctionField};</li>
     *   <li>{@code "predict"}: the field of {@link HasPredictField}. Otherwise, for a supervised estimator
     *       whose mining function is classification or regression, the prediction of the schema label.
     *       Otherwise, the last prediction output field declared in the model's {@link Output};</li>
     *   <li>{@code "predict_proba"}: the probability fields created for a {@link HasClasses} estimator.</li>
     * </ul>
     *
     * @param estimator the estimator that was converted into {@code model}
     * @param predictFunc the Python prediction function name
     * @param schema the schema of the converted model
     * @param model the converted model
     * @param encoder the encoder used to export fields
     * @return the exported features
     * @throws IllegalArgumentException if {@code predictFunc} is unknown, or the estimator/model
     *     does not provide the fields that {@code predictFunc} requires
     * @throws NullPointerException if {@code predictFunc} is {@code null}
     */
    public static List<Feature> export(Estimator estimator, String predictFunc, Schema schema, Model model, SkLearnEncoder encoder) {
        switch (predictFunc) {
            case "apply": {
                // HasMultiApplyField takes precedence over HasApplyField when both are implemented
                if (estimator instanceof HasMultiApplyField) {
                    HasMultiApplyField hasMultiApplyField = (HasMultiApplyField)estimator;
                    return encoder.export(model, hasMultiApplyField.getApplyFields());
                }
                if (estimator instanceof HasApplyField) {
                    HasApplyField hasApplyField = (HasApplyField)estimator;
                    return encoder.export(model, hasApplyField.getApplyField());
                }
                throw unsupported(estimator, predictFunc);
            }
            case "decision_function": {
                if (estimator instanceof HasDecisionFunctionField) {
                    HasDecisionFunctionField hasDecisionFunctionField = (HasDecisionFunctionField)estimator;
                    return encoder.export(model, hasDecisionFunctionField.getDecisionFunctionField());
                }
                throw unsupported(estimator, predictFunc);
            }
            case "predict": {
                if (estimator instanceof HasPredictField) {
                    HasPredictField hasPredictField = (HasPredictField)estimator;
                    return encoder.export(model, hasPredictField.getPredictField());
                }

                if (estimator.isSupervised()) {
                    // NOTE(review): assumes the schema label is a ScalarLabel; any other label type
                    // throws ClassCastException here
                    ScalarLabel scalarLabel = (ScalarLabel)schema.getLabel();

                    MiningFunction miningFunction = estimator.getMiningFunction();
                    switch (miningFunction) {
                        case CLASSIFICATION:
                        case REGRESSION: {
                            Feature feature = encoder.exportPrediction(model, scalarLabel);
                            return Collections.singletonList(feature);
                        }
                        default:
                            throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported by the '" + predictFunc + "' prediction function");
                    }
                }

                // Unsupervised estimator: use the last prediction output field already declared on the model
                Output output = model.getOutput();
                if (output != null && output.hasOutputFields()) {
                    List<OutputField> outputFields = output.getOutputFields();
                    List<OutputField> predictionOutputFields = outputFields.stream()
                        .filter(outputField -> SkLearnEncoder.isPrediction(outputField))
                        .collect(Collectors.toList());

                    if (!predictionOutputFields.isEmpty()) {
                        OutputField predictionOutputField = Iterables.getLast(predictionOutputFields);
                        return encoder.export(model, predictionOutputField.getName());
                    }
                }
                throw unsupported(estimator, predictFunc);
            }
            case "predict_proba": {
                if (estimator instanceof HasClasses) {
                    // NOTE(review): assumes the schema label is a CategoricalLabel; any other label type
                    // throws ClassCastException here
                    CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

                    List<OutputField> probabilityOutputFields = estimator.createPredictProbaFields(DataType.DOUBLE, categoricalLabel);
                    List<String> names = probabilityOutputFields.stream()
                        .map(outputField -> outputField.requireName())
                        .collect(Collectors.toList());
                    return encoder.export(model, names);
                }
                throw unsupported(estimator, predictFunc);
            }
            default:
                throw new IllegalArgumentException("Prediction function '" + predictFunc + "' is not supported");
        }
    }

    /**
     * Builds the exception reported when an estimator (or its converted model) cannot supply
     * the fields required by a prediction function.
     */
    private static IllegalArgumentException unsupported(Estimator estimator, String predictFunc) {
        return new IllegalArgumentException("The estimator object (" + ClassDictUtil.formatClass(estimator) + ") does not support the '" + predictFunc + "' prediction function");
    }
}

