/*
 * Decompiled with CFR 0.152.
 */
package optbinning;

import com.google.common.math.DoubleMath;
import java.util.ArrayList;
import java.util.List;
import optbinning.OptimalBinning;
import optbinning.OptimalBinningUtil;
import org.jpmml.converter.CMatrixUtil;

/**
 * Converter-side representation of an optbinning {@code MulticlassOptimalBinning} object.
 *
 * <p>Reads the number of classes ({@code _n_classes}) and the flattened event-count matrix
 * ({@code _n_event}) from the underlying object's attributes. It derives one output value per
 * matrix row using the {@code mean_woe} metric, the only metric supported here.
 */
public class MulticlassOptimalBinning
extends OptimalBinning {

    public MulticlassOptimalBinning(String module, String name) {
        super(module, name);
    }

    /**
     * Computes one output value per row of the event-count matrix using the {@code mean_woe} metric.
     *
     * <p>{@code _n_event} is interpreted as a {@code rows x classes} matrix in C (row-major) order.
     * For every class {@code k}, a one-vs-rest Weight of Evidence is computed for each row {@code r}
     * as {@code ln((1 / eventRate - 1) * totalEvents[k] / totalNonEvents[k])}. Here
     * {@code eventRate = events[r][k] / records[r]} and {@code records[r]} is the sum of row {@code r}.
     * NaN WoE values are replaced with 0. The value returned for a row is the arithmetic mean of
     * its per-class WoE values.
     *
     * @return one mean WoE value per matrix row
     * @throws IllegalArgumentException if the metric is not {@code mean_woe}; if the number of
     *         classes is not positive; if the number of event counts is not a multiple of the number
     *         of classes; or if a row has a non-finite (infinite) WoE value, which Guava's
     *         {@code DoubleMath.mean} rejects
     */
    @Override
    public List<Double> getCategoriesOut() {
        String metric = this.getMetric();
        Integer numberOfClasses = this.getNumberOfClasses();
        List<Integer> numberOfEvents = this.getNumberOfEvents();

        switch (metric) {
            case "mean_woe":
                break;
            default:
                throw new IllegalArgumentException("Metric '" + metric + "' is not supported");
        }

        int cols = numberOfClasses;
        if (cols < 1) {
            throw new IllegalArgumentException("Expected a positive number of classes, got " + cols);
        }
        // Guard against silently dropping trailing counts through integer division
        if (numberOfEvents.size() % cols != 0) {
            throw new IllegalArgumentException("Expected the number of event counts to be a multiple of " + cols + ", got " + numberOfEvents.size());
        }
        int rows = numberOfEvents.size() / cols;

        // Per-row totals across all classes
        List<Integer> numberOfRecords = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            List<Integer> rowEventCounts = CMatrixUtil.getRow(numberOfEvents, rows, cols, row);
            numberOfRecords.add(OptimalBinningUtil.sumExact(rowEventCounts));
        }

        // One list of per-row WoE values per class
        List<List<Double>> woesByColumn = new ArrayList<>(cols);
        for (int col = 0; col < cols; col++) {
            List<Integer> eventCounts = CMatrixUtil.getColumn(numberOfEvents, rows, cols, col);
            woesByColumn.add(computeWoes(eventCounts, numberOfRecords));
        }

        List<Double> result = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            List<Double> woesByRow = new ArrayList<>(cols);
            for (List<Double> woes : woesByColumn) {
                woesByRow.add(woes.get(row));
            }
            result.add(DoubleMath.mean(woesByRow));
        }
        return result;
    }

    /**
     * Computes the one-vs-rest Weight of Evidence of a single class for every row.
     *
     * @param eventCounts per-row event counts of the class
     * @param numberOfRecords per-row totals across all classes (same length as {@code eventCounts})
     * @return one WoE value per row; NaN results are replaced with 0
     */
    private static List<Double> computeWoes(List<Integer> eventCounts, List<Integer> numberOfRecords) {
        int rows = eventCounts.size();

        List<Integer> nonEventCounts = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            nonEventCounts.add(numberOfRecords.get(row).intValue() - eventCounts.get(row).intValue());
        }

        // Class-wide ratio of events to non-events; it scales every row of this class identically
        double constant = (double)OptimalBinningUtil.sumExact(eventCounts) / (double)OptimalBinningUtil.sumExact(nonEventCounts);

        List<Double> woes = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            double eventRate = (double)eventCounts.get(row).intValue() / (double)numberOfRecords.get(row).intValue();
            double woe = Math.log((1.0 / eventRate - 1.0) * constant);
            // NaN arises e.g. from 0/0 on an empty row; it is treated as a neutral WoE
            if (Double.isNaN(woe)) {
                woe = 0.0;
            }
            woes.add(woe);
        }
        return woes;
    }

    /**
     * Returns {@code "mean_woe"}, the only metric supported by {@link #getCategoriesOut()}.
     */
    @Override
    public String getDefaultMetric() {
        return "mean_woe";
    }

    /**
     * Returns the value of the {@code _n_classes} attribute.
     */
    public Integer getNumberOfClasses() {
        return this.getInteger("_n_classes");
    }

    /**
     * Returns the {@code _n_event} attribute, a flattened event-count matrix with
     * {@link #getNumberOfClasses()} columns.
     */
    @Override
    public List<Integer> getNumberOfEvents() {
        return this.getIntegerArray("_n_event");
    }
}

