/*
 * Decompiled with CFR 0.152.
 */
package sklearn;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.NDArrayUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.HasApplyField;
import sklearn.HasClasses;
import sklearn.HasNumberOfOutputs;
import sklearn.Step;
import sklearn.StepUtil;

/**
 * Base class for scikit-learn estimators that are encoded as PMML {@link Model}s.
 *
 * <p>Subclasses provide the mining function ({@link #getMiningFunction()}) and the
 * model encoding ({@link #encodeModel(Schema)}). This class handles label/feature
 * validation, model naming, feature importances, and the {@code pmml_options_},
 * {@code pmml_feature_importances_} and {@code pmml_segment_id_} attributes.
 */
public abstract class Estimator extends Step implements HasNumberOfOutputs {

    /** Output field name constant; not referenced within this class. */
    public static final String FIELD_APPLY = "apply";

    /** Output field name constant; not referenced within this class. */
    public static final String FIELD_DECISION_FUNCTION = "decisionFunction";

    /** Output field name constant; not referenced within this class. */
    public static final String FIELD_PREDICT = "predict";

    private static final Logger logger = LoggerFactory.getLogger(Estimator.class);

    public Estimator(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the PMML mining function of this estimator.
     *
     * <p>Used by {@link #isSupervised()} to decide whether a label is expected.
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Encodes this estimator as a PMML model.
     *
     * <p>Invoked by {@link #encode(Schema)} after the schema's label and features
     * have been checked.
     *
     * @param schema the label and features to encode against
     * @return the encoded model
     */
    public abstract Model encodeModel(Schema schema);

    /**
     * Returns the number of input features recorded on this estimator.
     *
     * <p>Prefers {@code n_features_in_} when it is present and not {@code null}, then
     * {@code n_features_}.
     *
     * @return the feature count, or {@code -1} when neither attribute is available
     */
    @Override
    public int getNumberOfFeatures() {
        if (this.containsKey("n_features_in_") && this.get("n_features_in_") != null) {
            return this.getInteger("n_features_in_");
        }
        if (this.containsKey("n_features_")) {
            return this.getInteger("n_features_");
        }
        return -1;
    }

    /**
     * Returns the value of {@code n_outputs_}, or {@code -1} when it is absent.
     */
    @Override
    public int getNumberOfOutputs() {
        if (this.containsKey("n_outputs_")) {
            return this.getInteger("n_outputs_");
        }
        return -1;
    }

    /** Estimator outputs are treated as continuous by default. */
    @Override
    public OpType getOpType() {
        return OpType.CONTINUOUS;
    }

    /** Estimator outputs are treated as doubles by default. */
    @Override
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    /**
     * Tells whether this estimator requires a label.
     *
     * @return {@code true} for classification and regression, {@code false} for clustering
     * @throws IllegalArgumentException for any other mining function
     */
    public boolean isSupervised() {
        MiningFunction miningFunction = this.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
            case REGRESSION:
                return true;
            case CLUSTERING:
                return false;
            default:
                throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
        }
    }

    /**
     * Returns the algorithm name written into the encoded model.
     *
     * <p>Defaults to {@link #getClassName()}.
     */
    public String getAlgorithmName() {
        return this.getClassName();
    }

    /**
     * Validates the schema, encodes the model and decorates it.
     *
     * <p>The model name (from {@code getPMMLName()}) and the algorithm name are only
     * filled in when {@link #encodeModel(Schema)} left them unset. Feature importances
     * are attached afterwards.
     *
     * @param schema the label and features to encode against
     * @return the encoded model
     * @throws IllegalArgumentException if the label presence does not match {@link #isSupervised()}
     */
    public Model encode(Schema schema) {
        this.checkLabel(schema.getLabel());
        this.checkFeatures(schema.getFeatures());

        Model model = this.encodeModel(schema);

        // Respect names already chosen by the model-specific encoder
        if (model.getModelName() == null) {
            String pmmlName = this.getPMMLName();
            if (pmmlName != null) {
                model.setModelName(pmmlName);
            }
        }
        if (model.getAlgorithmName() == null) {
            model.setAlgorithmName(this.getAlgorithmName());
        }

        this.addFeatureImportances(model, schema);
        return model;
    }

    /**
     * Checks that a label is present exactly when this estimator is supervised.
     *
     * @param label the schema label, possibly {@code null}
     * @throws IllegalArgumentException on a mismatch
     */
    public void checkLabel(Label label) {
        boolean supervised = this.isSupervised();
        if (supervised) {
            if (label == null) {
                throw new IllegalArgumentException("Expected a label, got no label");
            }
        } else if (label != null) {
            throw new IllegalArgumentException("Expected no label, got " + label);
        }
    }

    /**
     * Validates the schema features by delegating to {@code StepUtil.checkNumberOfFeatures}.
     *
     * @param features the schema features
     */
    public void checkFeatures(List<? extends Feature> features) {
        StepUtil.checkNumberOfFeatures(this, features);
    }

    /**
     * Registers per-feature importances with the schema's encoder.
     *
     * <p>{@code pmml_feature_importances_} takes precedence over {@code feature_importances_}.
     * Nothing is done when neither attribute is present.
     *
     * @param model  the model the importances belong to
     * @param schema the schema whose features the importances are aligned with
     */
    public void addFeatureImportances(Model model, Schema schema) {
        List<? extends Number> featureImportances = this.getPMMLFeatureImportances();
        if (featureImportances == null) {
            featureImportances = this.getFeatureImportances();
        }
        if (featureImportances == null) {
            return;
        }

        ModelEncoder encoder = (ModelEncoder) schema.getEncoder();
        List<? extends Feature> features = schema.getFeatures();

        // Importances are matched to features positionally, so the sizes must agree
        ClassDictUtil.checkSize(features, featureImportances);
        for (int i = 0; i < features.size(); i++) {
            encoder.addFeatureImportance(model, features.get(i), featureImportances.get(i));
        }
    }

    /**
     * Looks up a conversion option.
     *
     * <p>The key is searched in {@code pmml_options_} first. Failing that, an attribute
     * of the same name on this estimator is used (with a warning). Otherwise the
     * default is returned.
     *
     * @param key          the option name
     * @param defaultValue the value to return when the option is not found
     * @return the option value, possibly {@code null}
     */
    public Object getOption(String key, Object defaultValue) {
        Map<String, ?> pmmlOptions = this.getPMMLOptions();
        if (pmmlOptions != null && pmmlOptions.containsKey(key)) {
            return pmmlOptions.get(key);
        }
        if (this.containsKey(key)) {
            logger.warn("Attribute '{}' is not set. Falling back to the surrogate attribute '{}'",
                    ClassDictUtil.formatMember(this, "pmml_options_"),
                    ClassDictUtil.formatMember(this, key));
            return this.get(key);
        }
        return defaultValue;
    }

    /**
     * Stores a single option in {@code pmml_options_}, creating the map if needed.
     */
    public void putOption(String key, Object value) {
        this.putOptions(Collections.singletonMap(key, value));
    }

    /**
     * Merges the given options into {@code pmml_options_}, creating the map if needed.
     *
     * @param options the options to add; existing keys are overwritten
     */
    public void putOptions(Map<String, ?> options) {
        // The stored map is read back as Map<String, ?>; it has to be treated as
        // Map<String, Object> to allow insertion
        @SuppressWarnings("unchecked")
        Map<String, Object> pmmlOptions = (Map<String, Object>) this.getPMMLOptions();
        if (pmmlOptions == null) {
            pmmlOptions = new LinkedHashMap<>();
            this.setPMMLOptions(pmmlOptions);
        }
        pmmlOptions.putAll(options);
    }

    /**
     * Tells whether {@code feature_importances_} or {@code pmml_feature_importances_} is present.
     */
    public boolean hasFeatureImportances() {
        return this.containsKey("feature_importances_") || this.containsKey("pmml_feature_importances_");
    }

    /**
     * Returns {@code feature_importances_}, or {@code null} when the attribute is absent.
     */
    public List<? extends Number> getFeatureImportances() {
        if (!this.containsKey("feature_importances_")) {
            return null;
        }
        return this.getNumberArray("feature_importances_");
    }

    /**
     * Returns {@code pmml_feature_importances_}, or {@code null} when the attribute is absent.
     */
    public List<? extends Number> getPMMLFeatureImportances() {
        if (!this.containsKey("pmml_feature_importances_")) {
            return null;
        }
        return this.getNumberArray("pmml_feature_importances_");
    }

    /**
     * Stores {@code pmml_feature_importances_} (converted via {@code NDArrayUtil.toArray}).
     *
     * @return this estimator
     */
    public Estimator setPMMLFeatureImportances(List<? extends Number> pmmlFeatureImportances) {
        this.put("pmml_feature_importances_", NDArrayUtil.toArray(pmmlFeatureImportances));
        return this;
    }

    /**
     * Returns the {@code pmml_options_} dict, or {@code null} when it is absent or {@code null}.
     */
    public Map<String, ?> getPMMLOptions() {
        Object value = this.get("pmml_options_");
        if (value == null) {
            return null;
        }
        return this.getDict("pmml_options_");
    }

    /**
     * Replaces the {@code pmml_options_} attribute.
     *
     * @return this estimator
     */
    public Estimator setPMMLOptions(Map<String, ?> pmmlOptions) {
        this.put("pmml_options_", pmmlOptions);
        return this;
    }

    /**
     * Returns the optional {@code pmml_segment_id_} scalar.
     *
     * <p>When set, it is appended to the names of probability and apply output fields.
     */
    public Object getPMMLSegmentId() {
        return this.getOptionalScalar("pmml_segment_id_");
    }

    /**
     * Sets {@code pmml_segment_id_}, or removes it when {@code segmentId} is {@code null}.
     *
     * @return this estimator
     */
    public Estimator setPMMLSegmentId(Object segmentId) {
        if (segmentId != null) {
            this.put("pmml_segment_id_", segmentId);
        } else {
            this.remove("pmml_segment_id_");
        }
        return this;
    }

    /** Returns the optional {@code _sklearn_version} attribute. */
    public String getSkLearnVersion() {
        return this.getOptionalString("_sklearn_version");
    }

    /**
     * Creates one probability output field per category of the label.
     *
     * <p>Field names are built from {@code "probability"}, the segment id (if set) and
     * the category value.
     *
     * @param dataType         the data type of the probability fields
     * @param categoricalLabel the label whose values define the fields
     * @return the probability fields, in label value order
     * @throws IllegalArgumentException if this estimator does not implement {@link HasClasses}
     */
    public List<OutputField> createPredictProbaFields(DataType dataType, CategoricalLabel categoricalLabel) {
        Object pmmlSegmentId = this.getPMMLSegmentId();
        if (!(this instanceof HasClasses)) {
            throw new IllegalArgumentException("Estimator " + this.getClassName() + " does not implement HasClasses");
        }
        List<?> values = categoricalLabel.getValues();
        return values.stream()
                .map(value -> {
                    String name = (pmmlSegmentId != null)
                            ? FieldNameUtil.create("probability", pmmlSegmentId, value)
                            : FieldNameUtil.create("probability", value);
                    return ModelUtil.createProbabilityField(name, dataType, value);
                })
                .collect(Collectors.toList());
    }

    /**
     * Creates the entity-id output field for the apply output.
     *
     * <p>Its name comes from {@link HasApplyField#getApplyField()}, suffixed with the
     * segment id when one is set.
     *
     * @param dataType the data type of the field
     * @return the new output field
     * @throws IllegalArgumentException if this estimator does not implement {@link HasApplyField}
     */
    public OutputField createApplyField(DataType dataType) {
        Object pmmlSegmentId = this.getPMMLSegmentId();
        if (!(this instanceof HasApplyField)) {
            throw new IllegalArgumentException("Estimator " + this.getClassName() + " does not implement HasApplyField");
        }
        HasApplyField hasApplyField = (HasApplyField) this;
        String name = hasApplyField.getApplyField();
        if (pmmlSegmentId != null) {
            name = FieldNameUtil.create(name, pmmlSegmentId);
        }
        return ModelUtil.createEntityIdField(name, dataType);
    }

    /**
     * Creates the apply field and appends it to the model's output.
     *
     * @param model    the model to add the field to
     * @param dataType the data type of the field
     * @param values   optional allowed values for the field; ignored when {@code null} or empty
     * @return the added output field
     */
    public OutputField encodeApplyOutput(Model model, DataType dataType, List<?> values) {
        OutputField applyField = this.createApplyField(dataType);
        if (values != null && !values.isEmpty()) {
            PMMLUtil.addValues(applyField, values);
        }
        Output output = ModelUtil.ensureOutput(model);
        output.getOutputFields().add(applyField);
        return applyField;
    }
}

