/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.pipeline;

import h2o.estimators.BaseEstimator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import numpy.core.NDArray;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningBuildTask;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.True;
import org.dmg.pmml.Value;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.PyClassDict;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.TupleUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.Classifier;
import sklearn.ClassifierUtil;
import sklearn.Estimator;
import sklearn.HasEstimator;
import sklearn.HasNumberOfFeatures;
import sklearn.Initializer;
import sklearn.Transformer;
import sklearn.TransformerUtil;
import sklearn.TypeUtil;
import sklearn.pipeline.Pipeline;
import sklearn2pmml.pipeline.Verification;

public class PMMLPipeline
extends Pipeline
implements HasEstimator<Estimator> {
    private static final Logger logger = LoggerFactory.getLogger(PMMLPipeline.class);

    /**
     * Creates a pipeline object registered under the default
     * {@code sklearn2pmml} module and {@code PMMLPipeline} class name.
     */
    public PMMLPipeline() {
        this("sklearn2pmml", "PMMLPipeline");
    }

    /**
     * Creates a pipeline object for the given module and class name.
     * Both values are handed to the superclass constructor unchanged.
     *
     * @param module the module name passed to the superclass
     * @param name the class name passed to the superclass
     */
    public PMMLPipeline(String module, String name) {
        super(module, name);
    }

    public PMML encodePMML() {
        List<? extends Transformer> transformers = this.getTransformers();
        Estimator estimator = this.getEstimator();
        Transformer predictTransformer = this.getPredictTransformer();
        Transformer predictProbaTransformer = this.getPredictProbaTransformer();
        Transformer applyTransformer = this.getApplyTransformer();
        List<String> activeFields = this.getActiveFields();
        ArrayList<String> probabilityFields = null;
        List<String> targetFields = this.getTargetFields();
        String repr = this.getRepr();
        Verification verification = this.getVerification();
        SkLearnEncoder encoder = new SkLearnEncoder();
        ContinuousLabel label = null;
        if (estimator.isSupervised()) {
            String targetField = null;
            if (targetFields != null) {
                ClassDictUtil.checkSize(1, targetFields);
                targetField = targetFields.get(0);
            }
            if (targetField == null) {
                targetField = "y";
                logger.warn("Attribute '" + ClassDictUtil.formatMember(this, "target_fields") + "' is not set. Assuming {} as the name of the target field", (Object)targetField);
            }
            MiningFunction miningFunction = estimator.getMiningFunction();
            switch (miningFunction) {
                case CLASSIFICATION: {
                    List<?> classes = ClassifierUtil.getClasses(estimator);
                    Map classExtensions = (Map)estimator.getOption("class_extensions", null);
                    DataType dataType = TypeUtil.getDataType(classes, DataType.STRING);
                    List<String> categories = ClassifierUtil.formatTargetCategories(classes);
                    DataField dataField = encoder.createDataField(FieldName.create((String)targetField), OpType.CATEGORICAL, dataType, categories);
                    ArrayList<1> visitors = new ArrayList<1>();
                    if (classExtensions != null) {
                        Set entries = classExtensions.entrySet();
                        Iterator iterator = entries.iterator();
                        while (iterator.hasNext()) {
                            Map.Entry entry = (Map.Entry)iterator.next();
                            String name = (String)entry.getKey();
                            final Map values = (Map)entry.getValue();
                            AbstractExtender valueExtender = new AbstractExtender(name){

                                public VisitorAction visit(Value pmmlValue) {
                                    Object value = values.get(pmmlValue.getValue());
                                    if (value != null) {
                                        value = ScalarUtil.decode(value);
                                        this.addExtension((PMMLObject)pmmlValue, ValueUtil.formatValue(value));
                                    }
                                    return super.visit(pmmlValue);
                                }
                            };
                            visitors.add(valueExtender);
                        }
                    }
                    for (Visitor visitor : visitors) {
                        visitor.applyTo((Visitable)dataField);
                    }
                    label = new CategoricalLabel(dataField);
                    break;
                }
                case REGRESSION: {
                    DataField dataField = encoder.createDataField(FieldName.create((String)targetField), OpType.CONTINUOUS, DataType.DOUBLE);
                    label = new ContinuousLabel(dataField);
                    break;
                }
                default: {
                    throw new IllegalArgumentException();
                }
            }
        }
        List<Feature> features = new ArrayList<Feature>();
        PyClassDict featureInitializer = estimator;
        try {
            Transformer transformer = TransformerUtil.getHead(transformers);
            if (transformer != null) {
                featureInitializer = transformer;
                if (!(transformer instanceof Initializer)) {
                    features = this.initFeatures(transformer, transformer.getOpType(), transformer.getDataType(), encoder);
                }
                features = this.encodeFeatures(features, encoder);
            } else {
                features = this.initFeatures(estimator, estimator.getOpType(), estimator.getDataType(), encoder);
            }
        }
        catch (UnsupportedOperationException uoe) {
            throw new IllegalArgumentException("The first transformer or estimator object (" + ClassDictUtil.formatClass(featureInitializer) + ") does not specify feature type information", uoe);
        }
        int numberOfFeatures = estimator.getNumberOfFeatures();
        if (numberOfFeatures > -1) {
            ClassDictUtil.checkSize(numberOfFeatures, features);
        }
        Schema schema = new Schema((Label)label, features);
        Model model = estimator.encodeModel(schema);
        if (predictTransformer != null) {
            OutputField predictField;
            Output output;
            if (model instanceof MiningModel) {
                MiningModel miningModel = (MiningModel)model;
                Model finalModel = PMMLPipeline.getFinalModel(miningModel);
                output = ModelUtil.ensureOutput((Model)finalModel);
            } else {
                output = ModelUtil.ensureOutput((Model)model);
            }
            FieldName name = FieldName.create((String)("predict(" + label.getName().getValue() + ")"));
            if (label instanceof ContinuousLabel) {
                predictField = ModelUtil.createPredictedField((FieldName)name, (DataType)label.getDataType(), (OpType)OpType.CONTINUOUS).setFinalResult(Boolean.valueOf(false));
            } else if (label instanceof CategoricalLabel) {
                predictField = ModelUtil.createPredictedField((FieldName)name, (DataType)label.getDataType(), (OpType)OpType.CATEGORICAL).setFinalResult(Boolean.valueOf(false));
            } else {
                throw new IllegalArgumentException();
            }
            output.addOutputFields(new OutputField[]{predictField});
            this.encodeOutput(predictTransformer, model, Collections.singletonList(predictField));
        }
        if (predictProbaTransformer != null) {
            CategoricalLabel categoricalLabel = (CategoricalLabel)label;
            List predictProbaFields = ModelUtil.createProbabilityFields((DataType)DataType.DOUBLE, (List)categoricalLabel.getValues());
            this.encodeOutput(predictProbaTransformer, model, predictProbaFields);
        }
        if (applyTransformer != null) {
            OutputField nodeIdField = ModelUtil.createEntityIdField((FieldName)FieldName.create((String)"nodeId")).setDataType(DataType.INTEGER);
            this.encodeOutput(applyTransformer, model, Collections.singletonList(nodeIdField));
        }
        if (estimator.isSupervised() && verification != null) {
            VerificationField verificationField;
            int i;
            if (activeFields == null) {
                throw new IllegalArgumentException();
            }
            int[] activeValuesShape = verification.getActiveValuesShape();
            int[] targetValuesShape = verification.getTargetValuesShape();
            ClassDictUtil.checkShapes(0, activeValuesShape, targetValuesShape);
            ClassDictUtil.checkShapes(1, activeFields.size(), (int[][])new int[][]{activeValuesShape});
            List<?> activeValues = verification.getActiveValues();
            List<?> targetValues = verification.getTargetValues();
            Object var22_27 = null;
            List<? extends Number> probabilityValues = null;
            boolean hasProbabilityValues = verification.hasProbabilityValues();
            if (estimator instanceof BaseEstimator) {
                BaseEstimator baseEstimator = (BaseEstimator)estimator;
                hasProbabilityValues &= baseEstimator.hasProbabilityDistribution();
            } else if (estimator instanceof Classifier) {
                Classifier classifier = (Classifier)estimator;
                hasProbabilityValues &= classifier.hasProbabilityDistribution();
            } else {
                hasProbabilityValues = false;
            }
            if (hasProbabilityValues) {
                int[] nArray = verification.getProbabilityValuesShape();
                probabilityFields = new ArrayList<String>();
                CategoricalLabel categoricalLabel = (CategoricalLabel)label;
                List values = categoricalLabel.getValues();
                for (String value : values) {
                    probabilityFields.add("probability(" + value + ")");
                }
                ClassDictUtil.checkShapes(0, activeValuesShape, nArray);
                ClassDictUtil.checkShapes(1, probabilityFields.size(), (int[][])new int[][]{nArray});
                probabilityValues = verification.getProbabilityValues();
            }
            Number precision = verification.getPrecision();
            Number zeroThreshold = verification.getZeroThreshold();
            int rows = activeValuesShape[0];
            LinkedHashMap<VerificationField, List> data = new LinkedHashMap<VerificationField, List>();
            if (activeFields != null) {
                for (i = 0; i < activeFields.size(); ++i) {
                    verificationField = ModelUtil.createVerificationField((FieldName)FieldName.create((String)activeFields.get(i)));
                    data.put(verificationField, CMatrixUtil.getColumn(activeValues, (int)rows, (int)activeFields.size(), (int)i));
                }
            }
            if (probabilityFields != null) {
                for (i = 0; i < probabilityFields.size(); ++i) {
                    verificationField = ModelUtil.createVerificationField((FieldName)FieldName.create((String)((String)probabilityFields.get(i)))).setPrecision(Double.valueOf(precision.doubleValue())).setZeroThreshold(Double.valueOf(zeroThreshold.doubleValue()));
                    data.put(verificationField, CMatrixUtil.getColumn(probabilityValues, (int)rows, (int)probabilityFields.size(), (int)i));
                }
            } else {
                for (i = 0; i < targetFields.size(); ++i) {
                    verificationField = ModelUtil.createVerificationField((FieldName)FieldName.create((String)targetFields.get(i)));
                    DataType dataType = label.getDataType();
                    switch (dataType) {
                        case DOUBLE: 
                        case FLOAT: {
                            verificationField.setPrecision(Double.valueOf(precision.doubleValue())).setZeroThreshold(Double.valueOf(zeroThreshold.doubleValue()));
                            break;
                        }
                    }
                    data.put(verificationField, CMatrixUtil.getColumn(targetValues, (int)rows, (int)targetFields.size(), (int)i));
                }
            }
            model.setModelVerification(ModelUtil.createModelVerification(data));
        }
        PMML pmml = encoder.encodePMML(model);
        if (repr != null) {
            Extension extension = new Extension().addContent(new Object[]{repr});
            MiningBuildTask miningBuildTask = new MiningBuildTask().addExtensions(new Extension[]{extension});
            pmml.setMiningBuildTask(miningBuildTask);
        }
        return pmml;
    }

    /**
     * Creates one wildcard feature per active field.
     *
     * <p>If the {@code active_fields} attribute is not set, field names {@code x1 .. xN} are
     * generated, where N is taken from {@code object} when it implements
     * {@link HasNumberOfFeatures}.</p>
     *
     * @param object the first transformer or the estimator of the pipeline
     * @param opType the operational type given to every created data field
     * @param dataType the data type given to every created data field
     * @param encoder the encoder that creates the data fields
     * @return the features, in the order of the active fields
     * @throws IllegalArgumentException if the active fields are not set and the number of
     *         input features cannot be determined from {@code object}
     */
    private List<Feature> initFeatures(PyClassDict object, OpType opType, DataType dataType, SkLearnEncoder encoder) {
        List<String> fieldNames = getActiveFields();
        if (fieldNames == null) {
            int count = (object instanceof HasNumberOfFeatures)
                ? ((HasNumberOfFeatures)((Object)object)).getNumberOfFeatures()
                : -1;
            if (count < 0) {
                throw new IllegalArgumentException("The first transformer or estimator object (" + ClassDictUtil.formatClass((Object)object) + ") does not specify the number of input features");
            }
            fieldNames = new ArrayList<String>(count);
            // Generated names are one-based: x1, x2, ...
            for (int index = 1; index <= count; index++) {
                fieldNames.add("x" + index);
            }
            logger.warn("Attribute '" + ClassDictUtil.formatMember(this, "active_fields") + "' is not set. Assuming {} as the names of active fields", fieldNames);
        }
        List<Feature> features = new ArrayList<Feature>();
        for (String fieldName : fieldNames) {
            FieldName name = FieldName.create(fieldName);
            DataField field = encoder.createDataField(name, opType, dataType);
            features.add(new WildcardFeature((PMMLEncoder)encoder, field));
        }
        return features;
    }

    /**
     * Applies an output transformer to the given output fields and appends the resulting
     * derived fields to the model as transformed-value output fields.
     *
     * <p>The transformation is encoded against a fresh, private {@link SkLearnEncoder}. Every
     * derived field that the transformer registers with that encoder is copied into the first
     * {@link Output} element, in visiting order, that already defines all of the given output
     * field names.</p>
     *
     * @param transformer the transformer to apply to the output fields
     * @param model the model whose output is extended
     * @param outputFields the output fields that serve as the input of the transformer
     * @throws IllegalArgumentException if the model has no {@link Output} element that defines
     *         all of the given output fields
     */
    private void encodeOutput(Transformer transformer, Model model, List<OutputField> outputFields) {
        SkLearnEncoder encoder = new SkLearnEncoder();
        List<Feature> features = new ArrayList<Feature>();
        final Set<FieldName> names = new HashSet<FieldName>();
        for (OutputField outputField : outputFields) {
            FieldName name = outputField.getName();
            DataField dataField = encoder.createDataField(name, outputField.getOpType(), outputField.getDataType());
            features.add(new WildcardFeature((PMMLEncoder)encoder, dataField));
            names.add(name);
        }
        // Only the derived fields registered with the private encoder are used below.
        transformer.encodeFeatures(features, encoder);

        // Finds the first Output element that defines every requested field name.
        class OutputFinder
        extends AbstractVisitor {
            private Output output = null;

            OutputFinder() {
            }

            @Override
            public VisitorAction visit(Output output) {
                if (output.hasOutputFields()) {
                    Set<FieldName> definedNames = new HashSet<FieldName>();
                    for (OutputField outputField : output.getOutputFields()) {
                        definedNames.add(outputField.getName());
                    }
                    if (definedNames.containsAll(names)) {
                        this.setOutput(output);
                        return VisitorAction.TERMINATE;
                    }
                }
                return super.visit(output);
            }

            public Output getOutput() {
                return this.output;
            }

            private void setOutput(Output output) {
                this.output = output;
            }
        }

        OutputFinder outputFinder = new OutputFinder();
        outputFinder.applyTo((Visitable)model);
        Output output = outputFinder.getOutput();
        if (output == null) {
            throw new IllegalArgumentException("The model does not have an Output element that defines output fields " + names);
        }

        Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();
        for (DerivedField derivedField : derivedFields.values()) {
            OutputField outputField = new OutputField(derivedField.getName(), derivedField.getDataType())
                .setOpType(derivedField.getOpType())
                .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                .setExpression(derivedField.getExpression());
            output.addOutputFields(outputField);
        }
    }

    /**
     * Returns the transformer steps of this pipeline, that is, every step except the last one.
     *
     * @return the second element of each non-final step tuple, as transformers
     */
    @Override
    public List<? extends Transformer> getTransformers() {
        List<Object[]> allSteps = getSteps();
        // The last step holds the estimator; an empty step list is passed through as is.
        List<Object[]> transformerSteps = allSteps.isEmpty()
            ? allSteps
            : allSteps.subList(0, allSteps.size() - 1);
        return TupleUtil.extractElementList(transformerSteps, 1, Transformer.class);
    }

    /**
     * Returns the estimator of this pipeline, which is the second element of the last step tuple.
     *
     * @return the final estimator
     * @throws IllegalArgumentException if the pipeline has no steps
     */
    @Override
    public Estimator getEstimator() {
        List<Object[]> allSteps = getSteps();
        if (allSteps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more elements, got zero elements");
        }
        int lastIndex = allSteps.size() - 1;
        return TupleUtil.extractElement(allSteps.get(lastIndex), 1, Estimator.class);
    }

    /**
     * Returns the steps of this pipeline, exactly as reported by the superclass.
     *
     * @return the list of step tuples
     */
    @Override
    public List<Object[]> getSteps() {
        List<Object[]> steps = super.getSteps();
        return steps;
    }

    /**
     * Stores the given step tuples under the {@code steps} attribute.
     *
     * @param steps the step tuples
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setSteps(List<Object[]> steps) {
        put("steps", steps);
        return this;
    }

    /**
     * Returns the transformer stored under the {@code predict_transformer} attribute.
     *
     * @return the transformer, or {@code null} if the attribute is absent or null
     */
    public Transformer getPredictTransformer() {
        return getTransformer("predict_transformer");
    }

    /**
     * Returns the transformer stored under the {@code predict_proba_transformer} attribute.
     *
     * @return the transformer, or {@code null} if the attribute is absent or null
     */
    public Transformer getPredictProbaTransformer() {
        return getTransformer("predict_proba_transformer");
    }

    /**
     * Returns the transformer stored under the {@code apply_transformer} attribute.
     *
     * @return the transformer, or {@code null} if the attribute is absent or null
     */
    public Transformer getApplyTransformer() {
        return getTransformer("apply_transformer");
    }

    /**
     * Looks up an optional transformer attribute.
     *
     * @param key the attribute name
     * @return the attribute value as a transformer, or {@code null} if the raw value is null
     */
    private Transformer getTransformer(String key) {
        // Probe the raw value first; the typed lookup is only made for a non-null value.
        boolean present = (get(key) != null);
        return present ? get(key, Transformer.class) : null;
    }

    /**
     * Returns the names stored under the {@code active_fields} attribute.
     *
     * @return the active field names, or {@code null} if the attribute is not present
     */
    public List<String> getActiveFields() {
        if (containsKey("active_fields")) {
            return getArray("active_fields", String.class);
        }
        return null;
    }

    /**
     * Stores the given names under the {@code active_fields} attribute, wrapped into an array object.
     *
     * @param activeFields the active field names
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setActiveFields(List<String> activeFields) {
        NDArray array = PMMLPipeline.toArray(activeFields);
        put("active_fields", array);
        return this;
    }

    /**
     * Returns the target field names.
     *
     * <p>The singular {@code target_field} attribute takes precedence and is returned as a
     * one-element list; otherwise the {@code target_fields} attribute is used.</p>
     *
     * @return the target field names, or {@code null} if neither attribute is present
     */
    public List<String> getTargetFields() {
        if (containsKey("target_field")) {
            String targetField = (String)get("target_field");
            return Collections.singletonList(targetField);
        }
        if (containsKey("target_fields")) {
            return getArray("target_fields", String.class);
        }
        return null;
    }

    /**
     * Stores the given names under the {@code target_fields} attribute, wrapped into an array object.
     *
     * @param targetFields the target field names
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setTargetFields(List<String> targetFields) {
        NDArray array = PMMLPipeline.toArray(targetFields);
        put("target_fields", array);
        return this;
    }

    /**
     * Returns the string stored under the {@code repr_} attribute.
     *
     * @return the attribute value, or {@code null} if there is none
     */
    public String getRepr() {
        Object value = get("repr_");
        return (String)value;
    }

    /**
     * Stores the given string under the {@code repr_} attribute.
     *
     * @param repr the string to store
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setRepr(String repr) {
        put("repr_", repr);
        return this;
    }

    /**
     * Returns the object stored under the {@code verification} attribute.
     *
     * @return the verification data, or {@code null} if there is none
     */
    public Verification getVerification() {
        Object value = get("verification");
        return (Verification)value;
    }

    /**
     * Stores the given object under the {@code verification} attribute.
     *
     * @param verification the verification data
     * @return this pipeline, for call chaining
     */
    public PMMLPipeline setVerification(Verification verification) {
        Object value = verification;
        put("verification", value);
        return this;
    }

    /**
     * Resolves the model that produces the final result of a mining model.
     *
     * <p>For a model chain, the model of the last segment is followed for as long as it is
     * itself a mining model. For every segmentation method other than model chain, select-first
     * and select-all, the current mining model is returned as is.</p>
     *
     * @param miningModel the mining model to resolve
     * @return the final model
     * @throws IllegalArgumentException if a select-first or select-all segmentation is
     *         encountered, or if the last segment of a model chain does not have a
     *         {@link True} predicate
     */
    private static Model getFinalModel(MiningModel miningModel) {
        MiningModel current = miningModel;
        // Walks down nested model chains iteratively.
        while (true) {
            Segmentation segmentation = current.getSegmentation();
            Segmentation.MultipleModelMethod method = segmentation.getMultipleModelMethod();
            switch (method) {
                case SELECT_FIRST:
                case SELECT_ALL:
                    throw new IllegalArgumentException();
                case MODEL_CHAIN:
                    break;
                default:
                    return current;
            }
            List<Segment> segments = segmentation.getSegments();
            Segment tail = segments.get(segments.size() - 1);
            if (!(tail.getPredicate() instanceof True)) {
                throw new IllegalArgumentException();
            }
            Model tailModel = tail.getModel();
            if (!(tailModel instanceof MiningModel)) {
                return tailModel;
            }
            current = (MiningModel)tailModel;
        }
    }

    /**
     * Wraps a list of strings into a new {@link NDArray} object.
     *
     * <p>The list is stored under the {@code data} key and {@code fortran_order} is set to false.</p>
     *
     * @param strings the values to wrap; stored by reference, not copied
     * @return the new array object
     */
    private static NDArray toArray(List<String> strings) {
        NDArray array = new NDArray();
        array.put("fortran_order", Boolean.FALSE);
        array.put("data", strings);
        return array;
    }
}

