/*
 * Decompiled with CFR 0.152.
 */
package sklearn;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Value;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.OrdinalLabel;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.python.CastFunction;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
import sklearn.HasClasses;

public abstract class Classifier
extends Estimator
implements HasClasses {
    public static final String FIELD_PROBABILITY = "probability";

    public Classifier(String module, String name) {
        super(module, name);
    }

    @Override
    public MiningFunction getMiningFunction() {
        return MiningFunction.CLASSIFICATION;
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    @Override
    public int getNumberOfOutputs() {
        int numberOfOutputs = super.getNumberOfOutputs();
        if (numberOfOutputs == -1) {
            numberOfOutputs = 1;
        }
        return numberOfOutputs;
    }

    @Override
    public List<?> getClasses() {
        if (this.hasattr("pmml_classes_")) {
            return this.getClasses("pmml_classes_");
        }
        return this.getClasses("classes_");
    }

    protected List<?> getClasses(String name) {
        List values = this.getListLike(name);
        values = values.stream().map(value -> {
            if (value instanceof HasArray) {
                HasArray hasArray = (HasArray)value;
                return Classifier.canonicalizeValues(hasArray.getArrayContent());
            }
            return value;
        }).collect(Collectors.toList());
        return Classifier.canonicalizeValues(values);
    }

    @Override
    public boolean hasProbabilityDistribution() {
        return true;
    }

    @Override
    public Label encodeLabel(List<String> names, SkLearnEncoder encoder) {
        List<?> classes = this.getClasses();
        if (names.size() == 1) {
            return this.encodeLabel(names.get(0), classes, encoder);
        }
        if (names.size() >= 2) {
            ArrayList<DiscreteLabel> labels = new ArrayList<DiscreteLabel>();
            for (int i = 0; i < names.size(); ++i) {
                final String name = names.get(i);
                CastFunction castFunction = new CastFunction<List<?>>(List.class){

                    public String formatMessage(Object object) {
                        return "The categories object of the " + (String)(name != null ? "'" + name + "' " : "<un-named> ") + " target field (" + ClassDictUtil.formatClass((Object)object) + ") is not supported";
                    }
                };
                List categories = (List)castFunction.apply(classes.get(i));
                DiscreteLabel label = this.encodeLabel(name, categories, encoder);
                labels.add(label);
            }
            return new MultiLabel(labels);
        }
        throw new IllegalArgumentException();
    }

    protected DiscreteLabel encodeLabel(String name, List<?> categories, SkLearnEncoder encoder) {
        DataType dataType = TypeUtil.getDataType(categories, (DataType)DataType.STRING);
        return this.encodeLabel(name, OpType.CATEGORICAL, dataType, categories, encoder);
    }

    protected DiscreteLabel encodeLabel(String name, OpType opType, DataType dataType, List<?> categories, SkLearnEncoder encoder) {
        if (name != null) {
            DataField dataField = encoder.createDataField(name, opType, dataType, categories);
            Map classExtensions = (Map)this.getOption("class_extensions", null);
            if (classExtensions != null) {
                this.addClassExtensions(dataField, classExtensions);
            }
            switch (opType) {
                case CATEGORICAL: {
                    return new CategoricalLabel((Field)dataField);
                }
                case ORDINAL: {
                    return new OrdinalLabel((Field)dataField);
                }
            }
            throw new IllegalArgumentException();
        }
        switch (opType) {
            case CATEGORICAL: {
                return new CategoricalLabel(dataType, categories);
            }
            case ORDINAL: {
                return new OrdinalLabel(dataType, categories);
            }
        }
        throw new IllegalArgumentException();
    }

    private void addClassExtensions(DataField dataField, Map<String, Map<String, ?>> classExtensions) {
        ArrayList<2> visitors = new ArrayList<2>();
        if (classExtensions != null) {
            Set<Map.Entry<String, Map<String, ?>>> entries = classExtensions.entrySet();
            for (Map.Entry entry : entries) {
                String name = (String)entry.getKey();
                final Map values = (Map)entry.getValue();
                AbstractExtender valueExtender = new AbstractExtender(name){

                    public VisitorAction visit(Value pmmlValue) {
                        Object value = values.get(pmmlValue.requireValue());
                        if (value != null) {
                            value = ScalarUtil.decode(value);
                            this.addExtension((PMMLObject)pmmlValue, ValueUtil.asString(value));
                        }
                        return super.visit(pmmlValue);
                    }
                };
                visitors.add(valueExtender);
            }
        }
        for (Visitor visitor : visitors) {
            visitor.applyTo((Visitable)dataField);
        }
    }

    public List<OutputField> encodePredictProbaOutput(Model model, DataType dataType, DiscreteLabel discreteLabel) {
        List<OutputField> predictProbaFields = this.createPredictProbaFields(dataType, discreteLabel);
        model = MiningModelUtil.getFinalModel((Model)model);
        Output output = ModelUtil.ensureOutput((Model)model);
        output.getOutputFields().addAll(predictProbaFields);
        return predictProbaFields;
    }

    public static Object canonicalizeValue(Object value) {
        if (value instanceof Long) {
            Long longValue = (Long)value;
            return Math.toIntExact(longValue);
        }
        return value;
    }

    public static List<?> canonicalizeValues(List<?> values) {
        return values.stream().map(value -> Classifier.canonicalizeValue(value)).collect(Collectors.toList());
    }
}

