/*
 * Decompiled with CFR 0.152.
 */
package sklego.meta;

import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.Transformer;
import sklearn.tree.HasTreeOptions;

public class EstimatorTransformer
extends Transformer
implements HasEstimator<Estimator> {
    public EstimatorTransformer(String module, String name) {
        super(module, name);
    }

    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        String predictFunc;
        Estimator estimator = this.getEstimator();
        switch (predictFunc = this.getPredictFunc()) {
            case "apply": 
            case "decision_function": 
            case "predict": 
            case "predict_proba": {
                break;
            }
            default: {
                throw new IllegalArgumentException(predictFunc);
            }
        }
        Schema schema = EstimatorTransformer.createSchema(estimator, features, encoder);
        switch (predictFunc) {
            case "apply": {
                if (!(estimator instanceof HasTreeOptions)) break;
                HasTreeOptions hasTreeOptions = (HasTreeOptions)estimator;
                estimator.putOption("winner_id", (Object)Boolean.TRUE);
                break;
            }
        }
        Model model = estimator.encode(schema);
        encoder.addTransformer(model);
        List result = EstimatorUtil.export((Estimator)estimator, (String)predictFunc, (Schema)schema, (Model)model, (SkLearnEncoder)encoder);
        Output output = model.getOutput();
        if (output != null && output.hasOutputFields()) {
            List outputFields = output.getOutputFields();
            outputFields.clear();
        }
        return result;
    }

    public Estimator getEstimator() {
        return (Estimator)this.get("estimator_", Estimator.class);
    }

    public String getPredictFunc() {
        return this.getString("predict_func");
    }

    private static Schema createSchema(Estimator estimator, List<Feature> features, SkLearnEncoder encoder) {
        ContinuousLabel label = null;
        if (estimator.isSupervised()) {
            MiningFunction miningFunction = estimator.getMiningFunction();
            switch (miningFunction) {
                case CLASSIFICATION: {
                    List categories = EstimatorUtil.getClasses((Estimator)estimator);
                    DataType dataType = TypeUtil.getDataType((Collection)categories, (DataType)DataType.STRING);
                    label = new CategoricalLabel(dataType, categories);
                    break;
                }
                case REGRESSION: {
                    label = new ContinuousLabel(DataType.DOUBLE);
                    break;
                }
                default: {
                    throw new IllegalArgumentException();
                }
            }
        }
        return new Schema((PMMLEncoder)encoder, label, features);
    }
}

