/*
 * Decompiled with CFR 0.152.
 */
package sklearn;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import net.razorvine.pickle.objects.ClassDict;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.HasNumberOfFeatures;
import sklearn.HasType;
import sklearn.Step;

/**
 * Base class for Scikit-Learn estimators that can be converted into a PMML {@link Model}.
 *
 * <p>Subclasses supply the mining function ({@link #getMiningFunction()}) and the actual
 * model encoding ({@link #encodeModel(Schema)}). This class provides shared handling of
 * fitted attributes such as {@code n_features_} and {@code feature_importances_}, as well as
 * the PMML-specific {@code pmml_options_} and {@code pmml_feature_importances_} attributes.
 */
public abstract class Estimator extends Step implements HasNumberOfFeatures, HasType {

    private static final Logger logger = LoggerFactory.getLogger(Estimator.class);

    /** Attribute holding a dict of PMML conversion options. */
    private static final String PMML_OPTIONS = "pmml_options_";

    /** Attribute holding PMML-specific feature importances (preferred over the Scikit-Learn ones). */
    private static final String PMML_FEATURE_IMPORTANCES = "pmml_feature_importances_";

    /** Attribute holding the Scikit-Learn feature importances. */
    private static final String FEATURE_IMPORTANCES = "feature_importances_";

    public Estimator(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the PMML mining function implemented by this estimator.
     *
     * @return the mining function; {@link #isSupervised()} recognizes
     *     {@code CLASSIFICATION}, {@code REGRESSION} and {@code CLUSTERING}
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Encodes this estimator as a PMML model for the given schema.
     *
     * @param schema the label and features the model is built against
     * @return the encoded model
     */
    public abstract Model encodeModel(Schema schema);

    /**
     * Returns the number of input features.
     *
     * @return the value of the {@code n_features_} attribute, or {@code -1} if that attribute
     *     is not present
     */
    @Override
    public int getNumberOfFeatures() {
        if (this.containsKey("n_features_")) {
            return this.getInteger("n_features_");
        }
        return -1;
    }

    /**
     * Returns the operational type of this estimator's output.
     *
     * @return always {@link OpType#CONTINUOUS}
     */
    @Override
    public OpType getOpType() {
        return OpType.CONTINUOUS;
    }

    /**
     * Returns the data type of this estimator's output.
     *
     * @return always {@link DataType#DOUBLE}
     */
    @Override
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    /**
     * Tells whether this estimator is trained against a target (label).
     *
     * @return {@code true} for classification and regression, {@code false} for clustering
     * @throws IllegalArgumentException if the mining function is any other value
     */
    public boolean isSupervised() {
        MiningFunction miningFunction = this.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
            case REGRESSION:
                return true;
            case CLUSTERING:
                return false;
            default:
                // Name the offending value so that unsupported estimators are easy to diagnose
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    /**
     * Encodes this estimator as a PMML model and attaches feature importances to it, if any.
     *
     * @param schema the label and features the model is built against
     * @return the encoded model
     */
    public Model encode(Schema schema) {
        Model model = this.encodeModel(schema);
        this.addFeatureImportances(model, schema);
        return model;
    }

    /**
     * Attaches per-feature importance values to the given model.
     *
     * <p>Importances are read from {@code pmml_feature_importances_} if present, otherwise from
     * {@code feature_importances_}. If neither attribute is present, this method does nothing.
     * Each importance is paired positionally with the schema feature at the same index and
     * passed to that feature's encoder.
     *
     * @param model the model to annotate
     * @param schema the schema whose features the importances correspond to
     */
    public void addFeatureImportances(Model model, Schema schema) {
        List<? extends Number> featureImportances = this.getPMMLFeatureImportances();
        if (featureImportances == null) {
            featureImportances = this.getFeatureImportances();
        }
        if (featureImportances == null) {
            return;
        }

        List<? extends Feature> features = schema.getFeatures();
        // Importances are matched to features by position, so both lists must have the same length
        ClassDictUtil.checkSize(features, featureImportances);

        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);
            Number featureImportance = featureImportances.get(i);

            // NOTE(review): assumes every feature was created by a SkLearnEncoder; a different
            // encoder type fails here with a ClassCastException
            SkLearnEncoder encoder = (SkLearnEncoder) feature.getEncoder();
            encoder.addFeatureImportance(model, feature.getName(), featureImportance);
        }
    }

    /**
     * Looks up a conversion option.
     *
     * <p>The value is taken from the {@code pmml_options_} dict when that dict contains
     * {@code key}. Otherwise, if this object has an attribute named {@code key}, that attribute
     * is used and a warning is logged. Otherwise {@code defaultValue} is returned.
     *
     * @param key the option name
     * @param defaultValue the value returned when the option is not set anywhere
     * @return the option value, which may be {@code null} if it was explicitly set to null
     */
    public Object getOption(String key, Object defaultValue) {
        Map<String, ?> pmmlOptions = this.getPMMLOptions();
        if (pmmlOptions != null && pmmlOptions.containsKey(key)) {
            return pmmlOptions.get(key);
        }

        if (this.containsKey(key)) {
            logger.warn("Attribute '{}' is not set. Falling back to the surrogate attribute '{}'",
                    ClassDictUtil.formatMember(this, PMML_OPTIONS), ClassDictUtil.formatMember(this, key));
            return this.get(key);
        }

        return defaultValue;
    }

    /**
     * Tells whether feature importances are available.
     *
     * @return {@code true} if either {@code feature_importances_} or
     *     {@code pmml_feature_importances_} is present
     */
    public boolean hasFeatureImportances() {
        return this.containsKey(FEATURE_IMPORTANCES) || this.containsKey(PMML_FEATURE_IMPORTANCES);
    }

    /**
     * Returns the Scikit-Learn feature importances.
     *
     * @return the {@code feature_importances_} values, or {@code null} if the attribute is absent
     */
    public List<? extends Number> getFeatureImportances() {
        if (!this.containsKey(FEATURE_IMPORTANCES)) {
            return null;
        }
        return this.getNumberArray(FEATURE_IMPORTANCES);
    }

    /**
     * Returns the PMML-specific feature importances.
     *
     * @return the {@code pmml_feature_importances_} values, or {@code null} if the attribute
     *     is absent
     */
    public List<? extends Number> getPMMLFeatureImportances() {
        if (!this.containsKey(PMML_FEATURE_IMPORTANCES)) {
            return null;
        }
        return this.getNumberArray(PMML_FEATURE_IMPORTANCES);
    }

    /**
     * Stores PMML-specific feature importances, converted via {@code toArray}.
     *
     * @param pmmlFeatureImportances the importance values, one per feature
     * @return this estimator, for call chaining
     */
    public Estimator setPMMLFeatureImportances(List<? extends Number> pmmlFeatureImportances) {
        this.put(PMML_FEATURE_IMPORTANCES, Estimator.toArray(pmmlFeatureImportances));
        return this;
    }

    /**
     * Returns the PMML conversion options dict.
     *
     * @return the {@code pmml_options_} dict, or {@code null} if the attribute is absent or null
     */
    public Map<String, ?> getPMMLOptions() {
        Object value = this.get(PMML_OPTIONS);
        if (value == null) {
            return null;
        }
        return this.getDict(PMML_OPTIONS);
    }

    /**
     * Stores the PMML conversion options dict.
     *
     * @param pmmlOptions the options, keyed by option name
     * @return this estimator, for call chaining
     */
    public Estimator setPMMLOptions(Map<String, ?> pmmlOptions) {
        this.put(PMML_OPTIONS, pmmlOptions);
        return this;
    }

    /**
     * Returns the Scikit-Learn version recorded in the pickle.
     *
     * @return the {@code _sklearn_version} attribute, read via {@code getOptionalString}
     *     (presumably {@code null} when absent — TODO confirm)
     */
    public String getSkLearnVersion() {
        return this.getOptionalString("_sklearn_version");
    }
}

