/*
 * Decompiled with CFR 0.152.
 */
package sklearn.feature_selection;

import java.util.ArrayList;
import java.util.List;
import numpy.core.Scalar;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Estimator;
import sklearn.HasEstimator;
import sklearn.Selector;
import sklearn2pmml.EstimatorProxy;
import sklearn2pmml.SelectorProxy;

/**
 * Selector that keeps a feature when its entry in the fitted estimator's
 * {@code feature_importances_} array is greater than or equal to the persisted
 * {@code threshold_} value.
 */
public class SelectFromModel
extends Selector
implements HasEstimator<Estimator> {
    public SelectFromModel(String module, String name) {
        super(module, name);
    }

    /**
     * Reports the feature count of the wrapped estimator.
     *
     * @return the value of {@code getNumberOfFeatures()} on {@link #getEstimator()}
     */
    @Override
    public int getNumberOfFeatures() {
        return getEstimator().getNumberOfFeatures();
    }

    /**
     * Computes one boolean per feature importance of the wrapped estimator.
     *
     * @return a list with {@code true} at each position whose importance is
     *         {@code >=} the threshold returned by {@link #getThreshold()}
     * @throws IllegalArgumentException if the estimator's {@code feature_importances_}
     *         array cannot be read (the underlying exception is kept as the cause), or
     *         if the threshold itself cannot be read
     */
    @Override
    public List<Boolean> getSupportMask() {
        // Resolve the estimator and the threshold first, in this order, so that a
        // missing threshold is reported before a missing importances array.
        Estimator estimator = getEstimator();
        Number threshold = getThreshold();

        List<Number> importances;
        try {
            importances = estimator.getArray("feature_importances_", Number.class);
        } catch (RuntimeException re) {
            throw new IllegalArgumentException(describeMissingFeatureImportances(estimator), re);
        }

        List<Boolean> mask = new ArrayList<>(importances.size());
        for (Number importance : importances) {
            // The threshold is dereferenced per element. A NaN on either side of >= yields false.
            mask.add(importance.doubleValue() >= threshold.doubleValue());
        }
        return mask;
    }

    /**
     * Returns the fitted estimator stored under the {@code estimator_} attribute.
     *
     * @return the wrapped estimator
     */
    @Override
    public Estimator getEstimator() {
        return get("estimator_", Estimator.class);
    }

    /**
     * Reads the {@code threshold_} attribute as a {@link Scalar} and unwraps its single element.
     *
     * @return the only element of the threshold scalar, cast to {@link Number}
     * @throws IllegalArgumentException if the {@code threshold_} attribute cannot be
     *         read (the underlying exception is kept as the cause)
     */
    public Number getThreshold() {
        final Scalar scalar;
        try {
            scalar = get("threshold_", Scalar.class);
        } catch (RuntimeException re) {
            throw new IllegalArgumentException(describeMissingThreshold(), re);
        }
        // Unwrapping happens outside the try so its failures are not re-labelled as "missing attribute".
        return (Number) scalar.getOnlyElement();
    }

    /**
     * Builds the error text for an estimator whose {@code feature_importances_} could not be read.
     */
    private static String describeMissingFeatureImportances(Estimator estimator) {
        String estimatorClass = ClassDictUtil.formatClass(estimator);
        String proxyClass = EstimatorProxy.class.getName();
        String proxyExample = ClassDictUtil.formatProxyExample(EstimatorProxy.class, estimator);
        return "The estimator object (" + estimatorClass + ") does not have a persistent 'feature_importances_' attribute. "
            + "Please use the " + proxyClass + " wrapper class to give the estimator object a persistent state (eg. "
            + proxyExample + ")";
    }

    /**
     * Builds the error text for this selector when its {@code threshold_} could not be read.
     */
    private String describeMissingThreshold() {
        String selectorClass = ClassDictUtil.formatClass(this);
        String proxyClass = SelectorProxy.class.getName();
        String proxyExample = ClassDictUtil.formatProxyExample(SelectorProxy.class, this);
        return "The selector object (" + selectorClass + ") does not have a persistent 'threshold_' attribute. "
            + "Please use the " + proxyClass + " wrapper class to give the selector object a persistent state (eg. "
            + proxyExample + ")";
    }
}

