/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml;

import com.google.common.base.CharMatcher;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.TupleUtil;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.TypeUtil;
import sklearn.pipeline.Pipeline;
import sklearn_pandas.DataFrameMapper;

public class PMMLPipeline
extends Pipeline {
    public PMMLPipeline(String module, String name) {
        super(module, name);
    }

    public PMML encodePMML() {
        OpType opType;
        DataFrameMapper dataFrameMapper = this.getMapper();
        Estimator estimator = this.getEstimator();
        while (estimator instanceof Pipeline) {
            Pipeline pipeline = (Pipeline)estimator;
            estimator = pipeline.getEstimator();
        }
        SkLearnEncoder encoder = new SkLearnEncoder();
        Object label = null;
        if (estimator.isSupervised()) {
            String targetField = this.getTargetField();
            if (targetField == null) {
                targetField = "y";
            }
            opType = OpType.CONTINUOUS;
            DataType dataType = DataType.DOUBLE;
            List<String> targetCategories = null;
            if (estimator instanceof Classifier) {
                Classifier classifier = (Classifier)estimator;
                List<?> classes = classifier.getClasses();
                if (classes == null || classes.isEmpty()) {
                    throw new IllegalArgumentException();
                }
                opType = OpType.CATEGORICAL;
                dataType = TypeUtil.getDataType(classes, DataType.STRING);
                targetCategories = PMMLPipeline.formatTargetCategories(classes);
            }
            DataField dataField = encoder.createDataField(FieldName.create((String)targetField), opType, dataType, targetCategories);
            label = targetCategories != null && targetCategories.size() > 0 ? new CategoricalLabel(dataField) : new ContinuousLabel(dataField);
        }
        if (dataFrameMapper != null) {
            dataFrameMapper.encodeFeatures(encoder);
        } else {
            List<String> activeFields = this.getActiveFields();
            if (activeFields == null) {
                activeFields = new ArrayList<String>();
                int max = this.getNumberOfFeatures();
                for (int i = 0; i < max; ++i) {
                    activeFields.add("x" + String.valueOf(i + 1));
                }
            }
            opType = this.getOpType();
            DataType dataType = this.getDataType();
            for (String activeField : activeFields) {
                DataField dataField = encoder.createDataField(FieldName.create((String)activeField), opType, dataType);
                encoder.addRow(Collections.singletonList(activeField), Collections.singletonList(new WildcardFeature((PMMLEncoder)encoder, dataField)));
            }
        }
        Schema schema = new Schema(label, encoder.getFeatures());
        Model model = this.encodeModel(schema, encoder);
        return encoder.encodePMML(model);
    }

    public DataFrameMapper getMapper() {
        Object[] mapperStep = this.getMapperStep();
        if (mapperStep != null) {
            return (DataFrameMapper)((Object)TupleUtil.extractElement(mapperStep, 1));
        }
        return null;
    }

    public Object[] getMapperStep() {
        Object object;
        List<Object[]> transformerSteps = super.getTransformerSteps();
        if (transformerSteps.size() > 0 && (object = TupleUtil.extractElement(transformerSteps.get(0), 1)) instanceof DataFrameMapper) {
            return transformerSteps.get(0);
        }
        return null;
    }

    @Override
    public List<Object[]> getTransformerSteps() {
        Object object;
        List<Object[]> transformerSteps = super.getTransformerSteps();
        if (transformerSteps.size() > 0 && (object = TupleUtil.extractElement(transformerSteps.get(0), 1)) instanceof DataFrameMapper) {
            transformerSteps = transformerSteps.subList(1, transformerSteps.size());
        }
        return transformerSteps;
    }

    public List<String> getActiveFields() {
        return ClassDictUtil.getArray(this, "active_fields");
    }

    public String getTargetField() {
        return (String)this.get("target_field");
    }

    private static List<String> formatTargetCategories(List<?> objects) {
        Function<Object, String> function = new Function<Object, String>(){

            public String apply(Object object) {
                String targetCategory = ValueUtil.formatValue((Object)object);
                if (targetCategory == null || CharMatcher.WHITESPACE.matchesAnyOf((CharSequence)targetCategory)) {
                    throw new IllegalArgumentException(targetCategory);
                }
                return targetCategory;
            }
        };
        return new ArrayList<String>(Lists.transform(objects, (Function)function));
    }
}

