/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.pipeline;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.NDArrayUtil;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Extension;
import org.dmg.pmml.Header;
import org.dmg.pmml.MiningBuildTask;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.VerificationField;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.Estimator;
import sklearn.HasClasses;
import sklearn.Step;
import sklearn.Transformer;
import sklearn.pipeline.SkLearnPipeline;
import sklearn2pmml.HasPMMLOptions;
import sklearn2pmml.decoration.Domain;
import sklearn2pmml.pipeline.Verification;

/**
 * A scikit-learn pipeline wrapper that carries the extra attributes needed to produce a
 * complete PMML document: active/target field names, an optional header, a textual
 * representation, optional post-processing transformers and optional verification data.
 *
 * <p>All attributes are read from the underlying dictionary-like state of this object
 * (see the individual getters for the exact keys).
 */
public class PMMLPipeline extends SkLearnPipeline implements HasPMMLOptions<PMMLPipeline> {

    private static final Logger logger = LoggerFactory.getLogger(PMMLPipeline.class);

    /** Creates a pipeline registered under {@code sklearn2pmml.PMMLPipeline}. */
    public PMMLPipeline() {
        this("sklearn2pmml", "PMMLPipeline");
    }

    /**
     * Creates a pipeline registered under the given module and class name.
     *
     * @param module the module name passed to the superclass constructor
     * @param name the class name passed to the superclass constructor
     */
    public PMMLPipeline(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this pipeline as a PMML document.
     *
     * <p>When the pipeline has no final estimator, only the feature-initialization side effects
     * on the encoder are captured and a model-less document is returned. Otherwise the final
     * estimator is encoded and decorated, in this order, with pipeline-level feature
     * importances, transformed output fields and model verification data.
     *
     * @return the encoded PMML document
     * @throws IllegalArgumentException if a pipeline-level feature importance refers to an
     *     active field that has no data field in the encoder
     */
    @Override
    public PMML encodePMML() {
        SkLearnEncoder encoder = new SkLearnEncoder();

        // Read every attribute before any encoding work starts, so that a malformed attribute
        // is reported before the encoder has been touched.
        List<String> activeFields = getActiveFields();
        List<String> targetFields = getTargetFields();
        Map<?, ?> header = getHeader();
        String repr = getRepr();
        Transformer predictTransformer = getPredictTransformer();
        Transformer predictProbaTransformer = getPredictProbaTransformer();
        Transformer applyTransformer = getApplyTransformer();
        Verification verification = getVerification();

        Estimator estimator = null;
        if (hasFinalEstimator()) {
            estimator = getFinalEstimator();
            // The label is initialized before the features; keep this order.
            targetFields = initLabel(targetFields, encoder);
        }
        activeFields = initFeatures(activeFields, encoder);

        if (estimator == null) {
            return encodePMML(header, null, repr, encoder);
        }

        Schema schema = encoder.createSchema();
        Model model = estimator.encode(schema);
        encoder.setModel(model);

        // Pipeline-level importances are only a fallback for estimators that have none.
        if (!estimator.hasFeatureImportances()) {
            List<? extends Number> featureImportances = getPMMLFeatureImportances();
            if (featureImportances != null) {
                attachFeatureImportances(model, activeFields, featureImportances, encoder);
            }
        }

        if (predictTransformer != null || predictProbaTransformer != null || applyTransformer != null) {
            attachTransformedOutputs(estimator, model, schema, predictTransformer, predictProbaTransformer, applyTransformer, encoder);
        }

        if (estimator.isSupervised()) {
            if (verification == null) {
                logger.warn("Model verification data is not set. Use method '{}' to correct this deficiency", ClassDictUtil.formatMember(this, "verify(X)"));
            } else {
                attachModelVerification(estimator, model, schema, activeFields, targetFields, verification, encoder);
            }
        }

        return encodePMML(header, model, repr, encoder);
    }

    /**
     * Registers one feature importance per active field with the encoder.
     *
     * @throws IllegalArgumentException if an active field has no data field in the encoder
     */
    private void attachFeatureImportances(Model model, List<String> activeFields, List<? extends Number> featureImportances, SkLearnEncoder encoder) {
        ClassDictUtil.checkSize(activeFields, featureImportances);

        for (int i = 0; i < activeFields.size(); i++) {
            String activeField = activeFields.get(i);
            Number featureImportance = featureImportances.get(i);

            DataField dataField = encoder.getDataField(activeField);
            if (dataField == null) {
                throw new IllegalArgumentException("Field " + activeField + " is undefined");
            }

            Feature feature = new WildcardFeature(encoder, dataField);
            encoder.addFeatureImportance(model, feature, featureImportance);
        }
    }

    /**
     * Appends the output fields produced by the predict, predict_proba and apply transformers
     * (each optional) to the output of the final model.
     *
     * <p>The encoder's current model is switched to the final model for the duration of the
     * call and switched back to {@code model} on normal completion.
     */
    private void attachTransformedOutputs(Estimator estimator, Model model, Schema schema, Transformer predictTransformer, Transformer predictProbaTransformer, Transformer applyTransformer, SkLearnEncoder encoder) {
        Model finalModel = MiningModelUtil.getFinalModel(model);
        encoder.setModel(finalModel);

        Label label = schema.getLabel();
        Output output = ModelUtil.ensureOutput(finalModel);

        if (predictTransformer != null) {
            List<ScalarLabel> scalarLabels = ScalarLabelUtil.toScalarLabels(label);

            List<OutputField> predictFields = new ArrayList<OutputField>();
            for (ScalarLabel scalarLabel : scalarLabels) {
                // The (Object) cast forces the varargs overload, matching the original call.
                String name = FieldNameUtil.create("predict", (Object)scalarLabel.getName());
                OutputField predictField = ModelUtil.createPredictedField(name, scalarLabel.getOpType(), scalarLabel.getDataType())
                    .setFinalResult(false);

                output.addOutputFields(predictField);
                predictFields.add(predictField);
            }

            encodeOutput(output, predictFields, predictTransformer, encoder);
        }

        if (predictProbaTransformer != null) {
            // Fails with ClassCastException when the label is not categorical.
            CategoricalLabel categoricalLabel = (CategoricalLabel)label;

            List<OutputField> predictProbaFields = estimator.createPredictProbaFields(DataType.DOUBLE, categoricalLabel);
            encodeOutput(output, predictProbaFields, predictProbaTransformer, encoder);
        }

        if (applyTransformer != null) {
            OutputField applyField = estimator.createApplyField(DataType.INTEGER);
            encodeOutput(output, Collections.singletonList(applyField), applyTransformer, encoder);
        }

        encoder.setModel(model);
    }

    /**
     * Builds the model verification element from the verification data and sets it on
     * {@code model}.
     *
     * <p>Active field columns are always included. They are followed either by probability
     * columns (when verification data has probability values and, for estimators implementing
     * {@link HasClasses}, the estimator reports a probability distribution) or by target
     * columns.
     */
    private void attachModelVerification(Estimator estimator, Model model, Schema schema, List<String> activeFields, List<String> targetFields, Verification verification, SkLearnEncoder encoder) {
        Label label = schema.getLabel();

        List<?> activeValues = verification.getActiveValues();
        int[] activeValuesShape = verification.getActiveValuesShape();
        ClassDictUtil.checkShapes(1, activeFields.size(), activeValuesShape);

        int rows = activeValuesShape[0];

        // Insertion order defines the column order of the verification table.
        Map<VerificationField, List<?>> data = new LinkedHashMap<VerificationField, List<?>>();

        for (int i = 0; i < activeFields.size(); i++) {
            VerificationField verificationField = ModelUtil.createVerificationField(activeFields.get(i));
            Domain domain = encoder.getDomain(verificationField.requireField());

            data.put(verificationField, CMatrixUtil.getColumn(cleanValues(domain, activeValues), rows, activeFields.size(), i));
        }

        Number precision = verification.getPrecision();
        Number zeroThreshold = verification.getZeroThreshold();

        List<ScalarLabel> scalarLabels = ScalarLabelUtil.toScalarLabels(label);

        boolean hasProbabilityValues = verification.hasProbabilityValues();
        if (estimator instanceof HasClasses) {
            HasClasses hasClasses = (HasClasses)estimator;

            // Non-short-circuit on purpose: the estimator is always queried, as in the original.
            hasProbabilityValues &= hasClasses.hasProbabilityDistribution();
        }

        if (hasProbabilityValues) {
            List<? extends Number> probabilityValues = verification.getProbabilityValues();
            int[] probabilityValuesShape = verification.getProbabilityValuesShape();
            ClassDictUtil.checkShapes(0, activeValuesShape, probabilityValuesShape);

            ClassDictUtil.checkSize(1, scalarLabels);
            ScalarLabel scalarLabel = scalarLabels.get(0);

            List<String> probabilityFields = initProbabilityFields((CategoricalLabel)scalarLabel);
            ClassDictUtil.checkShapes(1, probabilityFields.size(), probabilityValuesShape);

            for (int i = 0; i < probabilityFields.size(); i++) {
                VerificationField verificationField = ModelUtil.createVerificationField(probabilityFields.get(i))
                    .setPrecision(precision)
                    .setZeroThreshold(zeroThreshold);

                data.put(verificationField, CMatrixUtil.getColumn(cleanValues(null, probabilityValues), rows, probabilityFields.size(), i));
            }
        } else {
            List<?> targetValues = verification.getTargetValues();
            int[] targetValuesShape = verification.getTargetValuesShape();
            ClassDictUtil.checkShapes(0, activeValuesShape, targetValuesShape);

            ClassDictUtil.checkSize(targetFields, scalarLabels);

            for (int i = 0; i < targetFields.size(); i++) {
                VerificationField verificationField = ModelUtil.createVerificationField(targetFields.get(i));
                ScalarLabel scalarLabel = scalarLabels.get(i);

                // Tolerances are only set for floating-point targets.
                DataType dataType = scalarLabel.getDataType();
                switch (dataType) {
                    case DOUBLE:
                    case FLOAT:
                        verificationField.setPrecision(precision).setZeroThreshold(zeroThreshold);
                        break;
                    default:
                        break;
                }

                Domain domain = encoder.getDomain(verificationField.requireField());

                data.put(verificationField, CMatrixUtil.getColumn(cleanValues(domain, targetValues), rows, targetFields.size(), i));
            }
        }

        model.setModelVerification(ModelUtil.createModelVerification(data));
    }

    /**
     * Produces the PMML document for {@code model} and applies the optional header entries and
     * the optional {@code repr} extension.
     *
     * @param header optional map; the keys {@code copyright}, {@code description} and
     *     {@code modelVersion} are copied to the PMML header (absent keys set {@code null})
     * @param model the model to encode, or {@code null} for a model-less document
     * @param repr optional text stored as a {@code repr} extension of the mining build task
     * @param encoder the encoder holding the fields and functions collected so far
     */
    private PMML encodePMML(Map<?, ?> header, Model model, String repr, SkLearnEncoder encoder) {
        PMML pmml = encoder.encodePMML(model);

        if (header != null) {
            Header pmmlHeader = pmml.requireHeader();
            pmmlHeader.setCopyright((String)header.get("copyright"));
            pmmlHeader.setDescription((String)header.get("description"));
            pmmlHeader.setModelVersion((String)header.get("modelVersion"));
        }

        if (repr != null) {
            // The (Object) cast forces the varargs overload, matching the original call.
            Extension extension = PMMLUtil.createExtension("repr", (Object)repr);
            MiningBuildTask miningBuildTask = new MiningBuildTask().addExtensions(extension);
            pmml.setMiningBuildTask(miningBuildTask);
        }

        return pmml;
    }

    /**
     * Runs {@code transformer} over {@code outputFields} in a scratch encoder and appends the
     * derived fields it produces to {@code output} as output fields.
     *
     * <p>Derived fields that are not already {@link DerivedOutputField}s are converted to
     * transformed-value output fields; only the last derived field in iteration order is
     * considered for the "final result" flag. Functions defined in the scratch encoder are
     * copied to {@code encoder}.
     */
    private void encodeOutput(Output output, List<OutputField> outputFields, Transformer transformer, SkLearnEncoder encoder) {
        SkLearnEncoder outputEncoder = new SkLearnEncoder();

        Model model = encoder.getModel();
        if (model != null) {
            outputEncoder.setModel(model);
        }

        List<Feature> features = new ArrayList<Feature>();
        for (OutputField outputField : outputFields) {
            DataField dataField = outputEncoder.createDataField(outputField.requireName(), outputField.requireOpType(), outputField.requireDataType());
            features.add(new WildcardFeature(outputEncoder, dataField));
        }

        transformer.encode(features, outputEncoder);

        // An explicit iterator is needed to detect the last element.
        Iterator<? extends DerivedField> it = outputEncoder.getDerivedFields().values().iterator();
        while (it.hasNext()) {
            DerivedField derivedField = it.next();

            OutputField outputField;
            if (derivedField instanceof DerivedOutputField) {
                outputField = ((DerivedOutputField)derivedField).getOutputField();
            } else {
                outputField = new OutputField(derivedField.requireName(), derivedField.requireOpType(), derivedField.requireDataType())
                    .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                    .setFinalResult(!it.hasNext())
                    .setExpression(derivedField.requireExpression());
            }

            output.addOutputFields(outputField);
        }

        for (DefineFunction defineFunction : outputEncoder.getDefineFunctions().values()) {
            encoder.addDefineFunction(defineFunction);
        }
    }

    /** Returns the pipeline steps as provided by the superclass. */
    @Override
    public List<Object[]> getSteps() {
        return super.getSteps();
    }

    /**
     * Replaces the pipeline steps.
     *
     * @return the object returned by the superclass, narrowed to {@code PMMLPipeline}
     */
    @Override
    public PMMLPipeline setSteps(List<Object[]> steps) {
        return (PMMLPipeline)super.setSteps(steps);
    }

    /**
     * Returns the PMML options of the final estimator.
     *
     * @return the options, or {@code null} if the pipeline has no final estimator
     */
    @Override
    public Map<String, ?> getPMMLOptions() {
        if (!hasFinalEstimator()) {
            return null;
        }

        return getFinalEstimator().getPMMLOptions();
    }

    /**
     * Forwards the PMML options to the final estimator; does nothing if there is none.
     *
     * @return this pipeline
     */
    @Override
    public PMMLPipeline setPMMLOptions(Map<String, ?> pmmlOptions) {
        if (hasFinalEstimator()) {
            getFinalEstimator().setPMMLOptions(pmmlOptions);
        }

        return this;
    }

    /** Returns the optional {@code header} attribute. */
    public Map<?, ?> getHeader() {
        return (Map<?, ?>)getOptional("header", Map.class);
    }

    /**
     * Returns the {@code pmml_feature_importances_} attribute.
     *
     * @return the importances, or {@code null} if the attribute is not present
     */
    public List<? extends Number> getPMMLFeatureImportances() {
        if (!containsKey("pmml_feature_importances_")) {
            return null;
        }

        return getNumberArray("pmml_feature_importances_");
    }

    /** Returns the optional {@code predict_transformer} attribute. */
    public Transformer getPredictTransformer() {
        return getTransformer("predict_transformer");
    }

    /** Returns the optional {@code predict_proba_transformer} attribute. */
    public Transformer getPredictProbaTransformer() {
        return getTransformer("predict_proba_transformer");
    }

    /** Returns the optional {@code apply_transformer} attribute. */
    public Transformer getApplyTransformer() {
        return getTransformer("apply_transformer");
    }

    /** Reads an optional transformer-valued attribute stored under {@code key}. */
    private Transformer getTransformer(String key) {
        return (Transformer)getOptional(key, Transformer.class);
    }

    /**
     * Returns the {@code active_fields} attribute.
     *
     * @return the field names, or {@code null} if the attribute is not present
     */
    public List<String> getActiveFields() {
        if (!containsKey("active_fields")) {
            return null;
        }

        return getListLike("active_fields", String.class);
    }

    /**
     * Stores the given names under the {@code active_fields} attribute.
     *
     * @return this pipeline
     */
    public PMMLPipeline setActiveFields(List<String> activeFields) {
        put("active_fields", NDArrayUtil.toArray(activeFields));

        return this;
    }

    /**
     * Returns the target field names.
     *
     * <p>The singular {@code target_field} attribute takes precedence over the plural
     * {@code target_fields} attribute when both are present.
     *
     * @return a single-element list for {@code target_field}, the list stored under
     *     {@code target_fields}, or {@code null} if neither attribute is present
     */
    public List<String> getTargetFields() {
        if (containsKey("target_field")) {
            return Collections.singletonList(getOptionalString("target_field"));
        }

        if (!containsKey("target_fields")) {
            return null;
        }

        return getListLike("target_fields", String.class);
    }

    /**
     * Stores the given names under the {@code target_fields} attribute.
     *
     * @return this pipeline
     */
    public PMMLPipeline setTargetFields(List<String> targetFields) {
        put("target_fields", NDArrayUtil.toArray(targetFields));

        return this;
    }

    /** Returns the optional {@code repr_} attribute. */
    public String getRepr() {
        return getOptionalString("repr_");
    }

    /**
     * Stores the given text under the {@code repr_} attribute.
     *
     * @return this pipeline
     */
    public PMMLPipeline setRepr(String repr) {
        put("repr_", repr);

        return this;
    }

    /** Returns the optional {@code verification} attribute. */
    public Verification getVerification() {
        return (Verification)getOptional("verification", Verification.class);
    }

    /**
     * Stores the given object under the {@code verification} attribute.
     *
     * @return this pipeline
     */
    public PMMLPipeline setVerification(Verification verification) {
        put("verification", verification);

        return this;
    }

    /**
     * Delegates to the superclass and warns that the {@code target_fields} attribute was not
     * set, reporting the names chosen by the superclass.
     */
    @Override
    protected List<String> initTargetFields(Estimator estimator) {
        List<String> targetFields = super.initTargetFields(estimator);

        logger.warn("Attribute '{}' is not set. Assuming {} as the name(s) of the target field(s)", ClassDictUtil.formatMember(this, "target_fields"), targetFields);

        return targetFields;
    }

    /**
     * Delegates to the superclass and warns that the {@code active_fields} attribute was not
     * set, reporting the names chosen by the superclass.
     */
    @Override
    protected List<String> initActiveFields(Step step) {
        List<String> activeFields = super.initActiveFields(step);

        logger.warn("Attribute '{}' is not set. Assuming {} as the names of active fields", ClassDictUtil.formatMember(this, "active_fields"), activeFields);

        return activeFields;
    }

    /** Creates one {@code probability(<value>)} field name per category of the label. */
    private List<String> initProbabilityFields(CategoricalLabel categoricalLabel) {
        List<String> probabilityFields = new ArrayList<String>();

        List<?> values = categoricalLabel.getValues();
        for (Object value : values) {
            probabilityFields.add(FieldNameUtil.create("probability", value));
        }

        return probabilityFields;
    }

    /**
     * Returns a lazily transformed view of {@code values} in which every element is passed
     * through {@link Domain#checkValue} and NaN elements are replaced with {@code null}.
     *
     * @param domain currently unused by the transformation; kept for call-site compatibility
     * @param values the values to clean; the returned list is a view, not a copy
     */
    private static List<?> cleanValues(Domain domain, List<?> values) {
        Function<Object, Object> function = new Function<Object, Object>() {

            @Override
            public Object apply(Object value) {
                Domain.checkValue(value);

                if (ValueUtil.isNaN(value)) {
                    return null;
                }

                return value;
            }
        };

        return Lists.transform(values, function);
    }
}

