/*
 * Decompiled with CFR 0.152.
 */
package sklego.meta;

import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Calibrator;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.Transformer;
import sklearn.tree.HasTreeOptions;

public class EstimatorTransformer
extends Transformer
implements HasEstimator<Estimator> {
    public EstimatorTransformer(String module, String name) {
        super(module, name);
    }

    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        if (!encoder.hasModel()) {
            return this.encodePreProcessor(features, encoder);
        }
        return this.encodePostProcessor(features, encoder);
    }

    private List<Feature> encodePreProcessor(List<Feature> features, SkLearnEncoder encoder) {
        String predictFunc;
        Estimator estimator = this.getEstimator();
        switch (predictFunc = this.getPredictFunc()) {
            case "apply": 
            case "decision_function": 
            case "predict": 
            case "predict_proba": {
                break;
            }
            default: {
                throw new IllegalArgumentException(predictFunc);
            }
        }
        Schema schema = EstimatorTransformer.createSchema(estimator, features, encoder);
        switch (predictFunc) {
            case "apply": {
                if (!(estimator instanceof HasTreeOptions)) break;
                HasTreeOptions hasTreeOptions = (HasTreeOptions)estimator;
                estimator.putOption("winner_id", (Object)Boolean.TRUE);
                break;
            }
        }
        Model model = estimator.encode(schema);
        encoder.addTransformer(model);
        List result = EstimatorUtil.export((Estimator)estimator, (String)predictFunc, (Schema)schema, (Model)model, (SkLearnEncoder)encoder);
        Output output = model.getOutput();
        if (output != null && output.hasOutputFields()) {
            List outputFields = output.getOutputFields();
            outputFields.clear();
        }
        return result;
    }

    private List<Feature> encodePostProcessor(List<Feature> features, SkLearnEncoder encoder) {
        String predictFunc;
        Calibrator estimator = this.getEstimator(Calibrator.class);
        switch (predictFunc = this.getPredictFunc()) {
            case "predict": {
                break;
            }
            default: {
                throw new IllegalArgumentException(predictFunc);
            }
        }
        Model model = encoder.getModel();
        features = estimator.encodeFeatures(features, encoder);
        ContinuousFeature continuousFeature = (ContinuousFeature)Iterables.getOnlyElement((Iterable)features);
        DerivedField derivedField = (DerivedField)continuousFeature.getField();
        String name = derivedField.requireName();
        encoder.removeDerivedField(name);
        OutputField outputField = new OutputField(name, derivedField.requireOpType(), derivedField.requireDataType()).setResultFeature(ResultFeature.TRANSFORMED_VALUE).setExpression(derivedField.requireExpression());
        Output output = ModelUtil.ensureOutput((Model)model);
        output.addOutputFields(new OutputField[]{outputField});
        return encoder.export(model, name);
    }

    public Estimator getEstimator() {
        return this.getEstimator(Estimator.class);
    }

    public <E extends Estimator> E getEstimator(Class<? extends E> clazz) {
        return (E)((Estimator)this.get("estimator_", clazz));
    }

    public String getPredictFunc() {
        return this.getString("predict_func");
    }

    private static Schema createSchema(Estimator estimator, List<Feature> features, SkLearnEncoder encoder) {
        ContinuousLabel label = null;
        if (estimator.isSupervised()) {
            MiningFunction miningFunction = estimator.getMiningFunction();
            switch (miningFunction) {
                case CLASSIFICATION: {
                    List categories = EstimatorUtil.getClasses((Estimator)estimator);
                    DataType dataType = TypeUtil.getDataType((Collection)categories, (DataType)DataType.STRING);
                    label = new CategoricalLabel(dataType, categories);
                    break;
                }
                case REGRESSION: {
                    label = new ContinuousLabel(DataType.DOUBLE);
                    break;
                }
                default: {
                    throw new IllegalArgumentException();
                }
            }
        }
        return new Schema((PMMLEncoder)encoder, label, features);
    }
}

