/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml;

import com.google.common.base.CharMatcher;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.NDArray;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningBuildTask;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.TupleUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.HasNumberOfFeatures;
import sklearn.Initializer;
import sklearn.Transformer;
import sklearn.TransformerUtil;
import sklearn.TypeUtil;
import sklearn.pipeline.Pipeline;

/**
 * A pipeline whose last step is an estimator and whose preceding steps are
 * transformers, and which can be encoded as a {@link PMML} document.
 *
 * <p>The attribute keys read by this class are {@code repr_},
 * {@code active_fields}, {@code target_field} and {@code target_fields};
 * {@code steps} is written by {@link #setSteps(List)} and read through the
 * superclass.
 */
public class PMMLPipeline extends Pipeline implements HasEstimator<Estimator> {
    private static final Logger logger = LoggerFactory.getLogger(PMMLPipeline.class);

    /** Creates a pipeline registered under {@code sklearn2pmml.PMMLPipeline}. */
    public PMMLPipeline() {
        this("sklearn2pmml", "PMMLPipeline");
    }

    /**
     * Creates a pipeline with an explicit class identity.
     *
     * @param module the module name passed to the superclass
     * @param name the class name passed to the superclass
     */
    public PMMLPipeline(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this pipeline as a PMML document.
     *
     * <p>For a supervised estimator a target field is declared first (named after
     * the single configured target field, or {@code "y"} when none is configured).
     * Input features are then initialized from the first transformer, or from the
     * estimator when there are no transformers, and passed through the transformer
     * steps. Finally the estimator encodes the model against the resulting schema.
     * If a {@code repr_} value is present it is attached as a mining build task
     * extension.
     *
     * @return the encoded PMML document
     * @throws IllegalArgumentException if the estimator is supervised but its mining
     *         function is neither classification nor regression, or if the number of
     *         input features cannot be determined
     */
    public PMML encodePMML() {
        List<Transformer> transformers = this.getTransformers();
        Estimator estimator = this.getEstimator();
        String repr = this.getRepr();
        SkLearnEncoder encoder = new SkLearnEncoder();

        // Declared as the common supertype: the classification branch assigns a
        // CategoricalLabel and the regression branch a ContinuousLabel.
        Label label = null;
        if (estimator.isSupervised()) {
            String targetField = null;
            List<String> targetFields = this.getTargetFields();
            if (targetFields != null) {
                // Only single-target pipelines are handled here.
                ClassDictUtil.checkSize(1, targetFields);
                targetField = targetFields.get(0);
            }
            if (targetField == null) {
                targetField = "y";
                logger.warn("The 'target_fields' attribute is not set. Assuming {} as the name of the target field", targetField);
            }
            MiningFunction miningFunction = estimator.getMiningFunction();
            switch (miningFunction) {
                case CLASSIFICATION: {
                    List<?> classes = EstimatorUtil.getClasses(estimator);
                    DataField dataField = encoder.createDataField(FieldName.create(targetField), OpType.CATEGORICAL, TypeUtil.getDataType(classes, DataType.STRING), PMMLPipeline.formatTargetCategories(classes));
                    label = new CategoricalLabel(dataField);
                    break;
                }
                case REGRESSION: {
                    DataField dataField = encoder.createDataField(FieldName.create(targetField), OpType.CONTINUOUS, DataType.DOUBLE);
                    label = new ContinuousLabel(dataField);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
                }
            }
        }

        List<Feature> features = new ArrayList<Feature>();
        Transformer transformer = TransformerUtil.getHead(transformers);
        if (transformer != null) {
            // An Initializer head starts from the empty feature list; any other head
            // gets wildcard features declared for it up front.
            if (!(transformer instanceof Initializer)) {
                features = this.initFeatures(transformer, transformer.getOpType(), transformer.getDataType(), encoder);
            }
            features = this.encodeFeatures(features, encoder);
        } else {
            features = this.initFeatures(estimator, estimator.getOpType(), estimator.getDataType(), encoder);
        }

        // A negative count means the estimator does not declare one, so there is
        // nothing to validate against.
        int numberOfFeatures = estimator.getNumberOfFeatures();
        if (numberOfFeatures > -1) {
            ClassDictUtil.checkSize(numberOfFeatures, features);
        }

        Schema schema = new Schema(label, features);
        Model model = estimator.encodeModel(schema, encoder);
        PMML pmml = encoder.encodePMML(model);
        if (repr != null) {
            Extension extension = new Extension().addContent(repr);
            MiningBuildTask miningBuildTask = new MiningBuildTask().addExtensions(extension);
            pmml.setMiningBuildTask(miningBuildTask);
        }
        return pmml;
    }

    /**
     * Declares one data field per active field and wraps each in a
     * {@link WildcardFeature}.
     *
     * <p>When {@code active_fields} is not set, names {@code x1..xN} are generated,
     * where N is the number of features reported by {@code object}.
     *
     * @param object the first transformer, or the estimator; consulted for the
     *        feature count only when no active fields are configured
     * @param opType the operational type of every created data field
     * @param dataType the data type of every created data field
     * @param encoder the encoder that the data fields are registered with
     * @return the created features, in active field order
     * @throws IllegalArgumentException if no active fields are configured and
     *         {@code object} does not report a non-negative number of features
     */
    private List<Feature> initFeatures(ClassDict object, OpType opType, DataType dataType, SkLearnEncoder encoder) {
        List<String> activeFields = this.getActiveFields();
        if (activeFields == null) {
            int numberOfFeatures = -1;
            if (object instanceof HasNumberOfFeatures) {
                numberOfFeatures = ((HasNumberOfFeatures)object).getNumberOfFeatures();
            }
            if (numberOfFeatures < 0) {
                throw new IllegalArgumentException("The first transformer or estimator object (" + ClassDictUtil.formatClass(object) + ") does not specify the number of input features");
            }
            activeFields = new ArrayList<String>(numberOfFeatures);
            for (int i = 0; i < numberOfFeatures; ++i) {
                // Generated names are one-based.
                activeFields.add("x" + (i + 1));
            }
            logger.warn("The 'active_fields' attribute is not set. Assuming {} as the names of active fields", activeFields);
        }
        List<Feature> result = new ArrayList<Feature>();
        for (String activeField : activeFields) {
            DataField dataField = encoder.createDataField(FieldName.create(activeField), opType, dataType);
            result.add(new WildcardFeature(encoder, dataField));
        }
        return result;
    }

    /**
     * Returns the transformer steps, which are all steps except the last one.
     *
     * @return the transformers built from element 1 of every step but the last;
     *         built from an empty step list when there are fewer than two steps
     */
    @Override
    public List<Transformer> getTransformers() {
        List<Object[]> steps = this.getSteps();
        if (steps.size() > 0) {
            // The last step is the estimator, see getEstimator().
            steps = steps.subList(0, steps.size() - 1);
        }
        return TransformerUtil.asTransformerList(TupleUtil.extractElementList(steps, 1));
    }

    /**
     * Returns the estimator, which is taken from the last step.
     *
     * @return the estimator built from element 1 of the last step
     * @throws IllegalArgumentException if there are no steps
     */
    @Override
    public Estimator getEstimator() {
        List<Object[]> steps = this.getSteps();
        if (steps.size() < 1) {
            throw new IllegalArgumentException("Expected one or more elements, got zero elements");
        }
        Object[] lastStep = steps.get(steps.size() - 1);
        return EstimatorUtil.asEstimator(TupleUtil.extractElement(lastStep, 1));
    }

    /**
     * Returns the steps as provided by the superclass.
     *
     * @return the pipeline steps
     */
    @Override
    public List<Object[]> getSteps() {
        return super.getSteps();
    }

    /**
     * Stores the steps under the {@code steps} key.
     *
     * @param steps the pipeline steps
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setSteps(List<Object[]> steps) {
        this.put("steps", steps);
        return this;
    }

    /**
     * Returns the value stored under the {@code repr_} key.
     *
     * @return the stored value, or whatever {@code get} yields when the key is absent
     */
    public String getRepr() {
        return (String)this.get("repr_");
    }

    /**
     * Stores a value under the {@code repr_} key.
     *
     * @param repr the value to store
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setRepr(String repr) {
        this.put("repr_", repr);
        return this;
    }

    /**
     * Returns the configured active field names.
     *
     * @return the contents of {@code active_fields}, or {@code null} when the key is
     *         not set
     */
    public List<String> getActiveFields() {
        if (!this.containsKey("active_fields")) {
            return null;
        }
        return ClassDictUtil.getArray(this, "active_fields");
    }

    /**
     * Stores the active field names under the {@code active_fields} key.
     *
     * @param activeFields the active field names
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setActiveFields(List<String> activeFields) {
        this.put("active_fields", PMMLPipeline.toArray(activeFields));
        return this;
    }

    /**
     * Returns the configured target field names.
     *
     * <p>The singular {@code target_field} key takes precedence over
     * {@code target_fields} when both are present.
     *
     * @return a singleton list holding the {@code target_field} value, else the
     *         contents of {@code target_fields}, else {@code null} when neither key
     *         is set
     */
    public List<String> getTargetFields() {
        if (this.containsKey("target_field")) {
            return Collections.singletonList((String)this.get("target_field"));
        }
        if (!this.containsKey("target_fields")) {
            return null;
        }
        return ClassDictUtil.getArray(this, "target_fields");
    }

    /**
     * Stores the target field names under the {@code target_fields} key.
     *
     * @param targetFields the target field names
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setTargetFields(List<String> targetFields) {
        this.put("target_fields", PMMLPipeline.toArray(targetFields));
        return this;
    }

    /**
     * Wraps a list of strings in an {@link NDArray} that has its {@code data} and
     * {@code fortran_order} entries set.
     *
     * @param strings the values stored as the array data (not copied)
     * @return the new array
     */
    private static NDArray toArray(List<String> strings) {
        NDArray result = new NDArray();
        result.put("data", strings);
        result.put("fortran_order", Boolean.FALSE);
        return result;
    }

    /**
     * Formats class labels as target category strings.
     *
     * <p>The result is a Guava transformed view, so formatting and validation run
     * when an element is accessed, not when this method is called.
     *
     * @param objects the class labels
     * @return a view of the formatted labels; accessing an element throws
     *         {@link IllegalArgumentException} if its formatted value is
     *         {@code null} or contains whitespace
     */
    private static List<String> formatTargetCategories(List<?> objects) {
        Function<Object, String> function = new Function<Object, String>() {

            @Override
            public String apply(Object object) {
                String targetCategory = ValueUtil.formatValue(object);
                if (targetCategory == null || CharMatcher.WHITESPACE.matchesAnyOf(targetCategory)) {
                    throw new IllegalArgumentException(targetCategory);
                }
                return targetCategory;
            }
        };
        return Lists.transform(objects, function);
    }
}

