/*
 * Decompiled with CFR 0.152.
 */
package category_encoders;

import category_encoders.CategoryEncoder;
import category_encoders.CategoryEncoderUtil;
import category_encoders.MapEncoder;
import category_encoders.MapFeature;
import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import numpy.core.ScalarUtil;
import org.dmg.pmml.Field;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.InvalidValueDecorator;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import pandas.core.Series;
import pandas.core.SeriesUtil;
import sklearn.preprocessing.EncoderUtil;

/**
 * Converter that replaces every input feature with a {@link MapFeature} backed by a
 * per-category number taken from the {@link Series} that {@code getMapping()} holds
 * for the matching column.
 *
 * <p>The numbers are converted with {@code ValueUtil.asDouble} when {@code normalize}
 * is true and with {@code ValueUtil.asInteger} otherwise.
 */
public class CountEncoder
extends MapEncoder {
    public CountEncoder(String module, String name) {
        super(module, name);
    }

    @Override
    public String functionName() {
        return "count";
    }

    /**
     * Encodes each feature as a category-to-number lookup.
     *
     * @param features input features; must have the same size as {@code getCols()}
     *                 (enforced through {@code ClassDictUtil.checkSize})
     * @param encoder  encoder used to make the fields categorical and to register decorators
     * @return one {@link MapFeature} per input feature, in input order
     * @throws IllegalArgumentException if {@code dropInvariant} is true, if
     *         {@code handleMissing} is null or not one of {@code error}, {@code return_nan},
     *         {@code count}, {@code value}, if {@code handleUnknown} is null or not one of
     *         {@code error}, {@code value}, or if a leftover category has no entry in the
     *         column mapping
     */
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<?> cols = this.getCols();
        Boolean dropInvariant = this.getDropInvariant();
        String handleMissing = this.getHandleMissing();
        String handleUnknown = this.getHandleUnknown();
        Boolean normalize = this.getNormalize();
        Map<Object, Series> mapping = this.getMapping();
        Map<Object, Map<Object, String>> minGroupCategories = this.getMinGroupCategories();

        if (dropInvariant) {
            throw new IllegalArgumentException("dropInvariant must be false");
        }

        // Guarded explicitly (like handleUnknown below) so that a null value is reported
        // as an argument error instead of a NullPointerException from the string switch.
        if (handleMissing == null) {
            throw new IllegalArgumentException("handleMissing must not be null");
        }
        Object missingCategory = null;
        switch (handleMissing) {
            case "error":
            case "return_nan": {
                break;
            }
            case "count":
            case "value": {
                missingCategory = CategoryEncoder.CATEGORY_NAN;
                break;
            }
            default: {
                throw new IllegalArgumentException(handleMissing);
            }
        }

        if (handleUnknown == null) {
            throw new IllegalArgumentException("handleUnknown must not be null");
        }
        Integer defaultValue = null;
        switch (handleUnknown) {
            case "error": {
                break;
            }
            case "value": {
                defaultValue = this.getDefaultValue();
                break;
            }
            default: {
                throw new IllegalArgumentException(handleUnknown);
            }
        }
        boolean passUnknownValues = "value".equals(handleUnknown);

        ClassDictUtil.checkSize(new Collection<?>[]{features, cols});

        List<Feature> result = new ArrayList<Feature>();
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = features.get(i);
            Object col = cols.get(i);

            Map<Object, Number> categoryCounts = CountEncoder.loadCategoryCounts(mapping.get(col), normalize);

            Map<Object, String> leftoverCategories = minGroupCategories.get(col);
            if (leftoverCategories != null) {
                CountEncoder.spreadLeftoverCount(categoryCounts, leftoverCategories, col);
            }

            List<Object> categories = new ArrayList<Object>(categoryCounts.keySet());
            Field field = encoder.toCategorical(feature.getName(), EncoderUtil.filterCategories(categories));
            if (passUnknownValues) {
                // Upcasts below mirror the decompiled call sites; the callee signatures are
                // not visible here, so they are kept to leave overload resolution untouched.
                EncoderUtil.addDecorator(field, (Decorator)new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_IS, null), (ModelEncoder)encoder);
            }

            MapFeature mapFeature = new MapFeature((PMMLEncoder)encoder, feature, categoryCounts, missingCategory, (Number)defaultValue){

                @Override
                public String getDerivedName() {
                    return CountEncoder.this.createFieldName(CountEncoder.this.functionName(), new Object[]{this.getName()});
                }
            };
            result.add(mapFeature);
        }
        return result;
    }

    /**
     * Converts a column's series into a mutable category-to-number map.
     *
     * The generic signature of {@code SeriesUtil.toMap} is not visible from this file, so
     * the functions are passed as raw types; the unchecked warnings are confined to this helper.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Map<Object, Number> loadCategoryCounts(Series series, boolean normalize) {
        // The conditional must sit in an assignment context: method references inside a
        // conditional that is merely cast have no target type and do not compile.
        // NOTE(review): assumes ValueUtil.asDouble/asInteger accept a Number argument - confirm.
        Function<Number, Number> valueFunction = normalize ? ValueUtil::asDouble : ValueUtil::asInteger;
        return SeriesUtil.toMap(series, (Function)Functions.identity(), (Function)valueFunction);
    }

    /**
     * Removes the single leftover category from {@code categoryCounts} and assigns its
     * number to every key of {@code leftoverCategories}.
     *
     * {@code Iterables.getOnlyElement} rejects the map unless all of its values are equal.
     */
    private static void spreadLeftoverCount(Map<Object, Number> categoryCounts, Map<Object, String> leftoverCategories, Object col) {
        String leftoverCategory = Iterables.getOnlyElement(new HashSet<String>(leftoverCategories.values()));
        Number leftoverCount = categoryCounts.remove(leftoverCategory);
        if (leftoverCount == null) {
            throw new IllegalArgumentException("No mapping entry for leftover category '" + leftoverCategory + "' of column " + col);
        }
        for (Object category : leftoverCategories.keySet()) {
            categoryCounts.put(category, leftoverCount);
        }
    }

    /**
     * Returns the value used for unknown categories: {@code 0} when the
     * {@code handle_unknown} attribute is a string, otherwise the attribute read as an integer.
     */
    public Integer getDefaultValue() {
        Object handleUnknown = this.getOptionalObject("handle_unknown");
        if (handleUnknown instanceof String) {
            return 0;
        }
        return this.getInteger("handle_unknown");
    }

    /**
     * Returns the {@code handle_unknown} mode; an integer-valued attribute is reported
     * as {@code "value"}.
     */
    @Override
    public String getHandleUnknown() {
        Object handleUnknown = this.getOptionalObject("handle_unknown");
        if (handleUnknown instanceof Integer) {
            return "value";
        }
        return this.getOptionalString("handle_unknown");
    }

    /** Returns the {@code normalize} attribute. */
    public Boolean getNormalize() {
        return this.getBoolean("normalize");
    }

    /**
     * Returns the {@code _min_group_categories} attribute with its keys passed through
     * {@code ScalarUtil.decode} and its values viewed as maps.
     */
    public Map<Object, Map<Object, String>> getMinGroupCategories() {
        Map minGroupCategories = (Map)this.get("_min_group_categories", Map.class);
        return CategoryEncoderUtil.toTransformedMap(minGroupCategories, key -> ScalarUtil.decode((Object)key), value -> (Map)value);
    }
}

