/*
 * Decompiled with CFR 0.152.
 */
package sklearn.preprocessing;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BinaryThresholdFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.InvalidValueDecorator;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.ObjectFeature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.preprocessing.BaseEncoder;
import sklearn.preprocessing.EncoderUtil;

/**
 * Converter for a one-hot encoder that encodes several input features at once, each against
 * its own list of fitted categories.
 *
 * <p>Honours the {@code handle_unknown} parameter and the fitted {@code drop_idx_},
 * {@code _infrequent_enabled} and {@code _infrequent_indices} attributes.
 */
public class MultiOneHotEncoder
extends BaseEncoder {
    public MultiOneHotEncoder(String module, String name) {
        super(module, name);
    }

    /**
     * Expands every input feature into binary indicator features.
     *
     * <p>For each input feature the output contains, in order:
     * <ul>
     *   <li>one {@link BinaryFeature} per retained category, or a {@link MissingValueFeature}
     *   for a category that {@code EncoderUtil.isMissingCategory} recognizes;</li>
     *   <li>when infrequent grouping is active for the feature, one trailing {@link BinaryFeature}
     *   for the grouped infrequent category, unless {@code drop_idx_} selected an infrequent
     *   category for dropping.</li>
     * </ul>
     *
     * @param features the input features, one per fitted category list
     * @param encoder the encoder used to create and look up PMML fields
     * @return the generated indicator features, grouped by input feature in input order
     * @throws IllegalArgumentException if {@code handle_unknown} is not recognized, if an input
     *     feature has an unsupported type, or if no usable infrequent category placeholder exists
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<List<?>> categories = this.getCategories();
        Object drop = this.getDrop();
        List<Integer> dropIdx = (drop != null ? this.getDropIdx() : null);
        String handleUnknown = this.getHandleUnknown();
        boolean infrequentEnabled = this.getInfrequentEnabled();
        List<List<Integer>> infrequentIndices = (infrequentEnabled ? this.getInfrequentIndices() : null);

        ClassDictUtil.checkSize(categories, features);

        List<Feature> result = new ArrayList<>();

        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);
            // Categories may be strings, numbers or missing-value markers, hence Object
            List<Object> featureCategories = new ArrayList<Object>(categories.get(i));

            List<Integer> featureInfrequentIndices = (infrequentEnabled ? infrequentIndices.get(i) : null);
            boolean featureInfrequentEnabled = infrequentEnabled
                && featureInfrequentIndices != null
                && !featureInfrequentIndices.isEmpty();

            // The placeholder is derived from the incoming feature, before any type conversion below
            Object infrequentCategory = (featureInfrequentEnabled ? getInfrequentCategory(feature) : null);

            EncoderUtil.addDecorator(feature, createInvalidValueDecorator(handleUnknown, featureInfrequentEnabled, infrequentCategory));

            if (feature instanceof BinaryThresholdFeature) {
                ContinuousFeature continuousFeature = ((BinaryThresholdFeature)feature).toContinuousFeature();
                encoder.toCategorical(continuousFeature.getName(), null);
                feature = continuousFeature;
            } else if (feature instanceof CategoricalFeature) {
                featureCategories = reconcileCategories((CategoricalFeature)feature, featureCategories);
            } else if (feature instanceof ObjectFeature) {
                // Used as-is, together with the categories returned by getCategories()
            } else if (feature instanceof WildcardFeature) {
                feature = toCategoricalFeature((WildcardFeature)feature, featureCategories);
            } else {
                throw new IllegalArgumentException("Unsupported feature type " + (feature != null ? feature.getClass().getName() : null));
            }

            List<Object> featureInfrequentCategories = Collections.emptyList();
            if (featureInfrequentEnabled) {
                if (infrequentCategory == null || featureCategories.contains(infrequentCategory)) {
                    throw new IllegalArgumentException("Infrequent category placeholder " + infrequentCategory + " is undefined or collides with a fitted category");
                }
                // Resolved by value now, while featureCategories is still in categories_ order
                featureInfrequentCategories = selectValues(featureCategories, featureInfrequentIndices);
            }

            // drop_idx_ indexes the categories before infrequent grouping. Dropping a frequent
            // category removes it by that index; dropping an infrequent category removes the
            // whole grouped infrequent column instead.
            Integer featureDropIdx = (dropIdx != null ? dropIdx.get(i) : null);
            boolean dropInfrequent = false;
            if (featureDropIdx != null) {
                if (featureInfrequentEnabled && featureInfrequentIndices.contains(featureDropIdx)) {
                    dropInfrequent = true;
                } else {
                    featureCategories.remove(featureDropIdx.intValue());
                }
            }

            if (featureInfrequentEnabled) {
                featureCategories.removeAll(featureInfrequentCategories);
                feature = EncoderUtil.encodeRegroupFeature(this, feature, featureInfrequentCategories, infrequentCategory, encoder);
            }

            for (Object category : featureCategories) {
                if (EncoderUtil.isMissingCategory(category)) {
                    result.add(new MissingValueFeature(encoder, feature));
                } else {
                    result.add(new BinaryFeature(encoder, feature, category));
                }
            }

            if (featureInfrequentEnabled && !dropInfrequent) {
                result.add(new BinaryFeature(encoder, feature, infrequentCategory));
            }
        }

        return result;
    }

    public Object getDrop() {
        return this.getOptionalObject("drop");
    }

    /**
     * Returns the {@code drop_idx_} attribute as integers, or {@code null} when it is absent.
     * Individual entries stay {@code null} where no category is dropped for that feature.
     */
    public List<Integer> getDropIdx() {
        List<?> dropIdx = this.getNumberArray("drop_idx_");
        if (dropIdx == null) {
            return null;
        }
        // Lists.transform returns a lazy view over dropIdx
        return Lists.transform(dropIdx, number -> (number != null ? ValueUtil.asInteger((Number)number) : null));
    }

    public Boolean getInfrequentEnabled() {
        return this.getOptionalBoolean("_infrequent_enabled", false);
    }

    public List<List<Integer>> getInfrequentIndices() {
        return EncoderUtil.transformInfrequentIndices(this.getList("_infrequent_indices", HasArray.class));
    }

    /**
     * Chooses the invalid value treatment that corresponds to the {@code handle_unknown} setting.
     *
     * @throws IllegalArgumentException if {@code handleUnknown} is not a recognized value
     */
    private static Decorator createInvalidValueDecorator(String handleUnknown, boolean featureInfrequentEnabled, Object infrequentCategory) {
        switch (handleUnknown) {
            case "error":
                return new InvalidValueDecorator(InvalidValueTreatmentMethod.RETURN_INVALID, null);
            case "ignore":
                return new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_IS, null);
            case "infrequent_if_exist":
                // Unknown values are mapped onto the infrequent category only if one exists for this feature
                if (featureInfrequentEnabled) {
                    return new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_VALUE, infrequentCategory);
                }
                return new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_IS, null);
            default:
                throw new IllegalArgumentException(handleUnknown);
        }
    }

    /**
     * Returns the categories to encode for a categorical input feature.
     *
     * <p>The result is a copy of the feature's own values, whose count must match the fitted
     * categories. If the fitted categories end with a missing-value category that the feature's
     * values lack, a trailing placeholder is appended: NaN for FLOAT/DOUBLE features, {@code null}
     * otherwise.
     */
    private static List<Object> reconcileCategories(CategoricalFeature categoricalFeature, List<Object> featureCategories) {
        List<?> values = categoricalFeature.getValues();

        if (hasMissingCategory(featureCategories) && !hasMissingCategory(values)) {
            ClassDictUtil.checkSize(dropMissingCategory(featureCategories), values);

            List<Object> result = new ArrayList<Object>(values);
            switch (categoricalFeature.getDataType()) {
                case FLOAT:
                case DOUBLE:
                    result.add(Double.NaN);
                    break;
                default:
                    result.add(null);
                    break;
            }
            return result;
        }

        ClassDictUtil.checkSize(featureCategories, values);
        return new ArrayList<Object>(values);
    }

    /**
     * Converts a wildcard feature into a categorical feature over the fitted categories,
     * excluding a trailing missing-value category. The backing {@link DataField}'s data type is
     * updated to match the category values, defaulting to STRING.
     */
    private static CategoricalFeature toCategoricalFeature(WildcardFeature wildcardFeature, List<Object> featureCategories) {
        CategoricalFeature categoricalFeature = (CategoricalFeature)wildcardFeature.toCategoricalFeature(dropMissingCategory(featureCategories));

        DataType dataType = TypeUtil.getDataType(categoricalFeature.getValues(), DataType.STRING);
        DataField dataField = (DataField)categoricalFeature.getField();
        if (dataField.requireDataType() != dataType) {
            dataField.setDataType(dataType);
        }

        return categoricalFeature;
    }

    /** Returns true if the last element of {@code categories} is a missing-value category. */
    private static boolean hasMissingCategory(List<?> categories) {
        if (categories.isEmpty()) {
            return false;
        }
        return EncoderUtil.isMissingCategory(categories.get(categories.size() - 1));
    }

    /**
     * Returns {@code categories} without its trailing missing-value category.
     * The result is a view when trimmed, and the argument itself otherwise.
     */
    private static <E> List<E> dropMissingCategory(List<E> categories) {
        if (hasMissingCategory(categories)) {
            return categories.subList(0, categories.size() - 1);
        }
        return categories;
    }

    /**
     * Returns the placeholder value representing the grouped infrequent category, based on the
     * feature's data type, or {@code null} if the data type is not supported.
     *
     * <p>NOTE(review): FLOAT/DOUBLE features receive an Integer placeholder, so the collision check
     * against Double categories in encodeFeatures cannot detect a fitted -999.0 category. Confirm
     * this is intended.
     */
    private static Object getInfrequentCategory(Feature feature) {
        DataType dataType = feature.getDataType();
        switch (dataType) {
            case STRING:
                return "infrequent";
            case FLOAT:
            case DOUBLE:
            case INTEGER:
                return -999;
            default:
                return null;
        }
    }

    /** Returns the elements of {@code values} at {@code indices}, in iteration order of {@code indices}. */
    private static <E> List<E> selectValues(List<E> values, Collection<Integer> indices) {
        if (indices == null || indices.isEmpty()) {
            return Collections.emptyList();
        }
        List<E> result = new ArrayList<E>(indices.size());
        for (Integer index : indices) {
            result.add(values.get(index));
        }
        return result;
    }
}

