/*
 * Decompiled with CFR 0.152.
 */
package sklearn.impute;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.dmg.pmml.OpType;
import org.jpmml.converter.Feature;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.SkLearnTransformer;
import sklearn.impute.ImputerUtil;

public class SimpleImputer
extends SkLearnTransformer {
    private static final String STRATEGY_CONSTANT = "constant";
    private static final String STRATEGY_MEAN = "mean";
    private static final String STRATEGY_MEDIAN = "median";
    private static final String STRATEGY_MOST_FREQUENT = "most_frequent";

    public SimpleImputer(String module, String name) {
        super(module, name);
    }

    @Override
    public OpType getOpType() {
        String strategy;
        switch (strategy = this.getStrategy()) {
            case "constant": {
                DataType dataType = this.getDataType();
                return TypeUtil.getOpType((DataType)dataType);
            }
            case "mean": 
            case "median": {
                return OpType.CONTINUOUS;
            }
            case "most_frequent": {
                return OpType.CATEGORICAL;
            }
        }
        throw new IllegalArgumentException(strategy);
    }

    @Override
    public DataType getDataType() {
        String strategy = this.getStrategy();
        List<Object> statistics = this.getStatistics();
        switch (strategy) {
            case "constant": {
                return TypeUtil.getDataType(statistics, (DataType)DataType.STRING);
            }
            case "mean": 
            case "median": {
                return DataType.DOUBLE;
            }
            case "most_frequent": {
                return TypeUtil.getDataType(statistics, (DataType)DataType.STRING);
            }
        }
        throw new IllegalArgumentException(strategy);
    }

    @Override
    public int getNumberOfFeatures() {
        int[] shape = this.getStatisticsShape();
        return shape[0];
    }

    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Boolean addIndicator = this.getAddIndicator();
        Object missingValues = this.getMissingValues();
        List<Object> statistics = this.getStatistics();
        String strategy = this.getStrategy();
        ClassDictUtil.checkSize((Collection[])new Collection[]{features, statistics});
        if (ValueUtil.isNaN((Object)missingValues)) {
            missingValues = null;
        }
        MissingValueTreatmentMethod missingValueTreatment = SimpleImputer.parseStrategy(strategy);
        ArrayList<Feature> indicatorFeatures = new ArrayList<Feature>();
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = features.get(i);
            Object statistic = statistics.get(i);
            if (addIndicator.booleanValue()) {
                Feature indicatorFeature = ImputerUtil.encodeIndicatorFeature(this, feature, missingValues, encoder);
                indicatorFeatures.add(indicatorFeature);
            }
            feature = ImputerUtil.encodeFeature(this, feature, addIndicator, missingValues, statistic, missingValueTreatment, encoder);
            result.add(feature);
        }
        if (addIndicator.booleanValue()) {
            result.addAll(indicatorFeatures);
        }
        return result;
    }

    public Boolean getAddIndicator() {
        return this.getOptionalBoolean("add_indicator", Boolean.FALSE);
    }

    public Object getMissingValues() {
        return this.getOptionalScalar("missing_values");
    }

    public List<Object> getStatistics() {
        if (!this.hasattr("statistics_")) {
            return Collections.emptyList();
        }
        return this.getObjectArray("statistics_");
    }

    public int[] getStatisticsShape() {
        if (!this.hasattr("statistics_")) {
            return new int[]{0};
        }
        return this.getArrayShape("statistics_", 1);
    }

    public String getStrategy() {
        return (String)this.getEnum("strategy", arg_0 -> ((SimpleImputer)this).getString(arg_0), Arrays.asList(STRATEGY_CONSTANT, STRATEGY_MEAN, STRATEGY_MEDIAN, STRATEGY_MOST_FREQUENT));
    }

    private static MissingValueTreatmentMethod parseStrategy(String strategy) {
        switch (strategy) {
            case "constant": {
                return MissingValueTreatmentMethod.AS_VALUE;
            }
            case "mean": {
                return MissingValueTreatmentMethod.AS_MEAN;
            }
            case "median": {
                return MissingValueTreatmentMethod.AS_MEDIAN;
            }
            case "most_frequent": {
                return MissingValueTreatmentMethod.AS_MODE;
            }
        }
        throw new IllegalArgumentException(strategy);
    }
}

