/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sklearn;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import numpy.DType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.PickleUtil;
import org.jpmml.python.PythonEncoder;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasMultiType;
import sklearn.Step;
import sklearn.StepUtil;
import sklearn.ensemble.hist_gradient_boosting.TreePredictor;
import sklearn.neighbors.BinaryTree;
import sklearn.tree.Tree;
import sklearn2pmml.decoration.Alias;
import sklearn2pmml.decoration.Domain;

public class SkLearnEncoder
extends PythonEncoder {
    private Map<String, Domain> domains = new LinkedHashMap<String, Domain>();
    private Label label = null;
    private List<? extends Feature> features = Collections.emptyList();
    private Map<String, Feature> memory = new LinkedHashMap<String, Feature>();
    private Predicate predicate = null;
    private Model model = null;

    public void addTransformer(Model transformer) {
        if (this.hasModel()) {
            throw new IllegalStateException("Model is already defined");
        }
        super.addTransformer(transformer);
    }

    public Model encodeModel(Model model) {
        Predicate predicate = this.getPredicate();
        model = super.encodeModel(model);
        if (predicate == null) {
            return model;
        }
        MiningModel miningModel = (MiningModel)model;
        Segmentation segmentation = miningModel.requireSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.requireMultipleModelMethod();
        switch (multipleModelMethod) {
            case MODEL_CHAIN: {
                break;
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)segmentation, (Enum)multipleModelMethod);
            }
        }
        List segments = segmentation.requireSegments();
        Segment finalSegment = (Segment)segments.get(segments.size() - 1);
        finalSegment.setPredicate(predicate);
        Set miningFunctions = segments.stream().map(segment -> {
            Model segmentModel = segment.requireModel();
            return segmentModel.requireMiningFunction();
        }).collect(Collectors.toSet());
        if (miningFunctions.size() > 1) {
            miningModel.setMiningFunction(MiningFunction.MIXED);
        }
        return miningModel;
    }

    public Label initLabel(Estimator estimator, List<String> names) {
        List<? extends Feature> features = this.getFeatures();
        if (!features.isEmpty()) {
            throw new IllegalStateException();
        }
        Label label = estimator.encodeLabel(names, this);
        this.setLabel(label);
        return label;
    }

    public List<Feature> initFeatures(Step step, List<String> names) {
        HasMultiType hasMultiType = StepUtil.getType(step);
        ArrayList<Feature> features = new ArrayList<Feature>();
        for (int i = 0; i < names.size(); ++i) {
            String name = names.get(i);
            OpType opType = hasMultiType.getOpType(i);
            DataType dataType = hasMultiType.getDataType(i);
            DataField dataField = this.createDataField(name, opType, dataType);
            WildcardFeature feature = new WildcardFeature((PMMLEncoder)this, dataField);
            features.add((Feature)feature);
        }
        this.setFeatures(features);
        return features;
    }

    public Schema createSchema() {
        Label label = this.getLabel();
        List<? extends Feature> features = this.getFeatures();
        return new Schema((PMMLEncoder)this, label, features);
    }

    public List<Feature> export(Model model, String name) {
        return this.export(model, Collections.singletonList(name));
    }

    public List<Feature> export(Model model, List<String> names) {
        Output output = EstimatorUtil.getFinalOutput(model);
        if (output == null) {
            throw new IllegalArgumentException();
        }
        List outputFields = output.getOutputFields();
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (String name : names) {
            DerivedOutputField derivedOutputField = null;
            List<OutputField> nameOutputFields = SkLearnEncoder.selectOutputFields(name, outputFields);
            for (OutputField nameOutputField : nameOutputFields) {
                derivedOutputField = this.createDerivedField(model, nameOutputField, true);
            }
            Feature feature = derivedOutputField.toFeature((PMMLEncoder)this);
            result.add(feature);
            outputFields.removeAll(nameOutputFields);
        }
        return result;
    }

    public Feature exportPrediction(Model model, ScalarLabel scalarLabel) {
        String name = scalarLabel.isAnonymous() ? "predict" : FieldNameUtil.create((String)"predict", (Object[])new Object[]{scalarLabel.getName()});
        return this.exportPrediction(model, name, scalarLabel);
    }

    public Feature exportPrediction(Model model, String name, ScalarLabel scalarLabel) {
        OutputField outputField = ModelUtil.createPredictedField((String)name, (OpType)scalarLabel.getOpType(), (DataType)scalarLabel.getDataType()).setFinalResult(Boolean.valueOf(false));
        DerivedOutputField derivedOutputField = this.createDerivedField(model, outputField, false);
        return derivedOutputField.toFeature((PMMLEncoder)this);
    }

    public Feature exportProbability(Model model, Object value) {
        return this.exportProbability(model, FieldNameUtil.create((String)"probability", (Object[])new Object[]{value}), value);
    }

    public Feature exportProbability(Model model, String name, Object value) {
        OutputField probabilityOutputField = ModelUtil.createProbabilityField((String)name, (DataType)DataType.DOUBLE, (Object)value).setFinalResult(Boolean.valueOf(false));
        DerivedOutputField probabilityField = this.createDerivedField(model, probabilityOutputField, false);
        return probabilityField.toFeature((PMMLEncoder)this);
    }

    public DataField createDataField(String name) {
        return this.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
    }

    public DerivedField createDerivedField(String name, Expression expression) {
        return this.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, expression);
    }

    public void addDerivedField(DerivedField derivedField) {
        try {
            super.addDerivedField(derivedField);
        }
        catch (RuntimeException re) {
            String name = derivedField.requireName();
            String message = "Field " + name + " is already defined. Please refactor the pipeline so that it would not contain duplicate field declarations, or use the " + Alias.class.getName() + " wrapper class to override the default name with a custom name (eg. " + Alias.formatAliasExample() + ")";
            throw new IllegalArgumentException(message, re);
        }
    }

    public void renameFeature(Feature feature, String renamedName) {
        String name = feature.getName();
        org.dmg.pmml.Field pmmlField = this.getField(name);
        if (pmmlField instanceof DataField) {
            throw new IllegalArgumentException("User input field " + name + " cannot be renamed");
        }
        DerivedField derivedField = this.removeDerivedField(name);
        try {
            Field nameField = Feature.class.getDeclaredField("name");
            ReflectionUtil.setFieldValue((Field)nameField, (Object)feature, (Object)renamedName);
        }
        catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
        derivedField.setName(renamedName);
        this.addDerivedField(derivedField);
    }

    public void renameFeatures(List<Feature> features, List<String> renamedNames) {
        ClassDictUtil.checkSize((int)renamedNames.size(), (Collection[])new Collection[]{features});
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = features.get(i);
            String renamedName = renamedNames.get(i);
            this.renameFeature(feature, renamedName);
        }
    }

    public boolean isFrozen(String name) {
        Map<String, Domain> domains = this.getDomains();
        return domains.containsKey(name);
    }

    public Domain getDomain(String name) {
        Map<String, Domain> domains = this.getDomains();
        return domains.get(name);
    }

    public void setDomain(String name, Domain domain) {
        Map<String, Domain> domains = this.getDomains();
        if (domain != null) {
            domains.put(name, domain);
        } else {
            domains.remove(name);
        }
    }

    public Map<String, Domain> getDomains() {
        return this.domains;
    }

    public Label getLabel() {
        return this.label;
    }

    public void setLabel(Label label) {
        this.label = label;
    }

    public List<? extends Feature> getFeatures() {
        return this.features;
    }

    public void setFeatures(List<? extends Feature> features) {
        this.features = Objects.requireNonNull(features);
    }

    public void memorize(String name, Feature feature) {
        Map<String, Feature> memory = this.getMemory();
        memory.put(name, feature);
    }

    public Feature recall(String name) {
        Map<String, Feature> memory = this.getMemory();
        return memory.get(name);
    }

    public Map<String, Feature> getMemory() {
        return this.memory;
    }

    public Predicate getPredicate() {
        return this.predicate;
    }

    public void setPredicate(Predicate predicate) {
        this.predicate = predicate;
    }

    public boolean hasModel() {
        Model model = this.getModel();
        return model != null;
    }

    public Model getModel() {
        return this.model;
    }

    public void setModel(Model model) {
        this.model = model;
    }

    public static boolean isPrediction(OutputField outputField) {
        ResultFeature resultFeature = outputField.getResultFeature();
        switch (resultFeature) {
            case PREDICTED_VALUE: 
            case TRANSFORMED_VALUE: 
            case DECISION: {
                return true;
            }
        }
        return false;
    }

    private static List<OutputField> selectOutputFields(String name, List<OutputField> outputFields) {
        ArrayList<OutputField> result = new ArrayList<OutputField>();
        for (OutputField outputField : outputFields) {
            boolean prediction = SkLearnEncoder.isPrediction(outputField);
            if (prediction) {
                result.add(outputField);
            }
            if (!Objects.equals(name, outputField.requireName())) continue;
            if (prediction) {
                return result;
            }
            return Collections.singletonList(outputField);
        }
        throw new IllegalArgumentException(name);
    }

    static {
        ClassLoader clazzLoader = SkLearnEncoder.class.getClassLoader();
        PickleUtil.init((ClassLoader)clazzLoader, (String)"sklearn2pmml.properties");
        DType.addDefinition(BinaryTree.DTYPE_NODEDATA);
        DType.addDefinition(Tree.DTYPE_TREE_OLD);
        DType.addDefinition(Tree.DTYPE_TREE_NEW);
        DType.addDefinition(TreePredictor.DTYPE_PREDICTOR_OLD);
        DType.addDefinition(TreePredictor.DTYPE_PREDICTOR_NEW);
    }
}

