/*
 * Decompiled with CFR 0.152.
 */
package sklearn.feature_selection;

import java.util.ArrayList;
import java.util.List;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.Scalar;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.Selector;

/**
 * Selector backed by a fitted estimator: a feature is kept when the estimator's
 * {@code feature_importances_} entry for it is greater than or equal to the
 * selector's {@code threshold_}.
 *
 * <p>Counterpart of the Python {@code sklearn.feature_selection.SelectFromModel} class.
 * Reads the persistent {@code estimator_} and {@code threshold_} attributes.
 */
public class SelectFromModel
extends Selector
implements HasEstimator<Estimator> {
    /**
     * Creates a new instance; {@code module} and {@code name} are forwarded to {@link Selector}.
     */
    public SelectFromModel(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, as reported by the wrapped estimator.
     */
    @Override
    public int getNumberOfFeatures() {
        Estimator estimator = this.getEstimator();
        return estimator.getNumberOfFeatures();
    }

    /**
     * Computes the support mask: one entry per element of the estimator's
     * {@code feature_importances_}, {@code true} when that importance is
     * {@code >=} the threshold.
     *
     * @return a new mutable list of booleans, in feature order
     * @throws IllegalArgumentException if the estimator has no persistent
     *         {@code feature_importances_} attribute, or if this selector has no
     *         persistent {@code threshold_}/{@code estimator_} attribute
     */
    @Override
    public List<Boolean> getSupportMask() {
        Estimator estimator = this.getEstimator();
        // Resolved once, before the importances are read (same order as before),
        // so a missing threshold_ is reported first.
        double threshold = this.getThreshold().doubleValue();

        List<?> featureImportances = ClassDictUtil.getArray(estimator, "feature_importances_");
        if (featureImportances == null) {
            throw new IllegalArgumentException("The estimator object (" + ClassDictUtil.formatClass(estimator) + ") does not have a persistent 'feature_importances_' attribute");
        }

        List<Boolean> result = new ArrayList<>(featureImportances.size());
        for (Object value : featureImportances) {
            Number featureImportance = (Number)value;
            result.add(featureImportance.doubleValue() >= threshold);
        }
        return result;
    }

    /**
     * Returns the fitted estimator stored under {@code estimator_}.
     *
     * @throws IllegalArgumentException if the {@code estimator_} attribute is absent
     */
    @Override
    public Estimator getEstimator() {
        ClassDict estimator = (ClassDict)this.get("estimator_");
        if (estimator == null) {
            throw new IllegalArgumentException("The selector object does not have a persistent 'estimator_' attribute");
        }
        return EstimatorUtil.asEstimator(estimator);
    }

    /**
     * Returns the fitted selection threshold stored under {@code threshold_}.
     *
     * @throws IllegalArgumentException if the {@code threshold_} attribute is absent
     */
    public Number getThreshold() {
        Object threshold = this.get("threshold_");
        if (threshold == null) {
            throw new IllegalArgumentException("The selector object does not have a persistent 'threshold_' attribute");
        }
        // A numeric threshold may be persisted as a plain Python float (unpickled as a
        // java.lang.Double) rather than a NumPy scalar; accept both representations.
        if (threshold instanceof Number) {
            return (Number)threshold;
        }
        Scalar scalar = (Scalar)threshold;
        return (Number)scalar.getOnlyElement();
    }
}

