/*
 * Decompiled with CFR 0.152.
 */
package category_encoders;

import category_encoders.CategoryEncoder;
import category_encoders.CategoryEncoderUtil;
import category_encoders.MapEncoder;
import category_encoders.MapFeature;
import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import numpy.core.ScalarUtil;
import org.dmg.pmml.Field;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.InvalidValueDecorator;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import pandas.core.BlockManager;
import pandas.core.DataFrame;
import pandas.core.Index;
import pandas.core.Series;
import pandas.core.SeriesUtil;
import pandas.core.SingleBlockManager;
import sklearn.preprocessing.EncoderUtil;

public abstract class MeanEncoder
extends MapEncoder {
    public MeanEncoder(String module, String name) {
        super(module, name);
    }

    public abstract MeanFunction createFunction();

    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<?> cols = this.getCols();
        Boolean dropInvariant = this.getDropInvariant();
        String handleMissing = this.getHandleMissing();
        String handleUnknown = this.getHandleUnknown();
        Map<Object, Series> mapping = this.getMapping();
        if (dropInvariant.booleanValue()) {
            throw new IllegalArgumentException();
        }
        Object missingCategory = null;
        switch (handleMissing) {
            case "error": 
            case "return_nan": {
                break;
            }
            case "value": {
                missingCategory = CategoryEncoder.CATEGORY_NAN;
                break;
            }
            default: {
                throw new IllegalArgumentException(handleMissing);
            }
        }
        Double defaultValue = null;
        switch (handleUnknown) {
            case "error": {
                break;
            }
            case "value": {
                defaultValue = this.getMean();
                break;
            }
            default: {
                throw new IllegalArgumentException(handleUnknown);
            }
        }
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = features.get(i);
            Object col = cols.get(i);
            Series series = mapping.get(col);
            Map categoryMeans = SeriesUtil.toMap((Series)series, (Function)Functions.identity(), ValueUtil::asDouble);
            ArrayList categories = new ArrayList(categoryMeans.keySet());
            Field field = encoder.toCategorical(feature.getName(), EncoderUtil.filterCategories(categories));
            switch (handleUnknown) {
                case "value": {
                    EncoderUtil.addDecorator((Field)field, (Decorator)new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_IS, null), (ModelEncoder)encoder);
                    break;
                }
            }
            MapFeature mapFeature = new MapFeature((PMMLEncoder)encoder, feature, categoryMeans, missingCategory, (Number)defaultValue){

                @Override
                public String getDerivedName() {
                    return MeanEncoder.this.createFieldName(MeanEncoder.this.functionName(), new Object[]{this.getName()});
                }
            };
            result.add((Feature)mapFeature);
        }
        return result;
    }

    @Override
    public Map<Object, Series> getMapping() {
        Map mapping = (Map)this.get("mapping", Map.class);
        return CategoryEncoderUtil.toTransformedMap(mapping, key -> ScalarUtil.decode((Object)key), value -> MeanEncoder.toMeanSeries((DataFrame)value, this.createFunction()));
    }

    public Double getMean() {
        return ValueUtil.asDouble((Number)this.getNumber("_mean"));
    }

    private static Series toMeanSeries(DataFrame dataFrame, MeanFunction function) {
        List countValues;
        List sumValues;
        BlockManager blockManager = dataFrame.getData();
        List axes = blockManager.getAxesArray();
        if (axes.size() != 2) {
            throw new IllegalArgumentException();
        }
        List firstDim = ((Index)axes.get(0)).getValues();
        List secondDim = ((Index)axes.get(1)).getValues();
        if (!Arrays.asList("sum", "count").equals(firstDim)) {
            throw new IllegalArgumentException();
        }
        List blockValues = blockManager.getBlockValues();
        if (blockValues.size() == 2) {
            sumValues = ((HasArray)blockValues.get(0)).getArrayContent();
            countValues = ((HasArray)blockValues.get(1)).getArrayContent();
        } else {
            HasArray blockValue = (HasArray)Iterables.getOnlyElement((Iterable)blockValues);
            List blockValueContent = blockValue.getArrayContent();
            int[] blockValueShape = blockValue.getArrayShape();
            sumValues = CMatrixUtil.getRow((List)blockValueContent, (int)blockValueShape[0], (int)blockValueShape[1], (int)0);
            countValues = CMatrixUtil.getRow((List)blockValueContent, (int)blockValueShape[0], (int)blockValueShape[1], (int)1);
        }
        final ArrayList<Double> meanValues = new ArrayList<Double>();
        for (int i = 0; i < sumValues.size(); ++i) {
            Double sum = ValueUtil.asDouble((Number)((Number)sumValues.get(i)));
            Integer count = ValueUtil.asInteger((Number)((Number)countValues.get(i)));
            Double mean = function.apply(sum, count);
            meanValues.add(mean);
        }
        HasArray hasArray = new HasArray(){

            public List<?> getArrayContent() {
                return meanValues;
            }

            public int[] getArrayShape() {
                return new int[]{meanValues.size()};
            }

            public Object getArrayType() {
                throw new UnsupportedOperationException();
            }
        };
        SingleBlockManager singleBlockManager = new SingleBlockManager();
        singleBlockManager.setOnlyBlockItem((Index)axes.get(1));
        singleBlockManager.setOnlyBlockValue(hasArray);
        Series result = new Series();
        result.setBlockManager(singleBlockManager);
        return result;
    }

    public static interface MeanFunction
    extends BiFunction<Double, Integer, Double> {
        @Override
        public Double apply(Double var1, Integer var2);
    }
}

