/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Hashtable;
import java.util.logging.Logger;

/**
 * A {@link RankedFeatureVector} whose per-feature value is the best gain ratio
 * (information gain divided by split information) obtainable by thresholding
 * that feature, together with the threshold that achieves it.
 *
 * <p>Candidate thresholds for a feature are the midpoints between consecutive
 * instances (in ascending feature-value order) whose feature values differ and
 * whose labelings differ. A candidate is only eligible as a feature's best
 * split if its information gain is at least the average information gain over
 * all scored candidates of all features.
 */
public class GainRatio
extends RankedFeatureVector {
    private static final Logger logger = MalletLogger.getLogger(GainRatio.class.getName());
    private static final long serialVersionUID = 1L;
    /** Natural logarithm of 2; dividing a natural log by this converts it to base 2. */
    public static final double log2 = Math.log(2.0);
    /** Best split point per feature index; NaN or 0.0 where no split was selected. */
    double[] m_splitPoints;
    /** Entropy (in bits) of the label distribution before any split. */
    double m_baseEntropy;
    /** Label distribution of the instances before any split. */
    LabelVector m_baseLabelDistribution;
    /** Number of scored candidate splits recorded for the top-ranked feature. */
    int m_numSplitPointsForBestFeature;
    /** Minimum number of instances required on each side of a split. */
    int m_minNumInsts;

    /**
     * Computes, for every feature in the data alphabet, the best gain ratio and
     * the split point that produces it, using only the given instances.
     *
     * @param ilist        the instances; data must be {@link FeatureVector}s
     * @param instIndices  indices into {@code ilist} of the instances to use
     * @param minNumInsts  minimum number of instances that must fall on each
     *                     side of a split for the split to be scored
     * @return a five-element array:
     *         {@code [0]} {@code double[]} gain ratio per feature,
     *         {@code [1]} {@code double[]} split point per feature,
     *         {@code [2]} {@code Double} base entropy,
     *         {@code [3]} {@link LabelVector} base label distribution,
     *         {@code [4]} {@code Integer} number of scored splits of the best feature.
     *         If there are no candidate splits, or the total information gain is
     *         (almost) zero, both arrays are all zeros and element 4 is 0.
     */
    protected static Object[] calcGainRatios(InstanceList ilist, int[] instIndices, int minNumInsts) {
        int numInsts = instIndices.length;
        Alphabet dataDict = ilist.getDataAlphabet();
        LabelAlphabet targetDict = (LabelAlphabet)ilist.getTargetAlphabet();
        int numFeatures = dataDict.size();
        int numLabels = targetDict.size();

        // Accumulate the (possibly fractional) label mass over all selected instances.
        double[] targetCounts = new double[numLabels];
        for (int ii = 0; ii < numInsts; ii++) {
            Instance inst = (Instance)ilist.get(instIndices[ii]);
            Labeling labeling = inst.getLabeling();
            double labelWeightSum = 0.0;
            for (int ll = 0; ll < labeling.numLocations(); ll++) {
                double labelWeight = labeling.valueAtLocation(ll);
                labelWeightSum += labelWeight;
                targetCounts[labeling.indexAtLocation(ll)] += labelWeight;
            }
            // Each instance's labeling is expected to be a distribution.
            assert (Maths.almostEquals(labelWeightSum, 1.0));
        }

        // Label distribution and entropy before splitting.
        double[] targetDistribution = new double[numLabels];
        double baseEntropy = 0.0;
        for (int ci = 0; ci < numLabels; ci++) {
            double p = targetCounts[ci] / (double)numInsts;
            targetDistribution[ci] = p;
            if (p > 0.0) {
                baseEntropy -= p * Math.log(p) / log2;
            }
        }
        LabelVector baseLabelDistribution = new LabelVector(targetDict, targetDistribution);

        // Pass 1: score every candidate split of every feature.
        // Per feature: split point -> (x = info gain, y = gain ratio).
        // Hashtable is retained deliberately: its iteration order decides ties in pass 2.
        double infoGainSum = 0.0;
        int totalNumSplitPoints = 0;
        double[] passTestTargetCounts = new double[numLabels];
        ArrayList<Hashtable<Double, Point2D.Double>> featureToInfo =
            new ArrayList<Hashtable<Double, Point2D.Double>>(numFeatures);
        int[] sortedIndices = instIndices;
        for (int fi = 0; fi < numFeatures; fi++) {
            if ((fi + 1) % 1000 == 0) {
                logger.info("at feature " + (fi + 1) + " / " + dataDict.size());
            }
            Hashtable<Double, Point2D.Double> splitInfo = new Hashtable<Double, Point2D.Double>();
            featureToInfo.add(splitInfo);
            Arrays.fill(passTestTargetCounts, 0.0);
            sortedIndices = GainRatio.sortInstances(ilist, sortedIndices, fi);

            // Walk the instances in ascending feature value; after processing position ii,
            // passTestTargetCounts holds the label mass of instances 0..ii.
            for (int ii = 0; ii < numInsts - 1; ii++) {
                Instance inst = (Instance)ilist.get(sortedIndices[ii]);
                Instance instPlusOne = (Instance)ilist.get(sortedIndices[ii + 1]);
                FeatureVector fv1 = (FeatureVector)inst.getData();
                FeatureVector fv2 = (FeatureVector)instPlusOne.getData();
                double lower = fv1.value(fi);
                double higher = fv2.value(fi);

                Labeling labeling = inst.getLabeling();
                for (int ll = 0; ll < labeling.numLocations(); ll++) {
                    passTestTargetCounts[labeling.indexAtLocation(ll)] += labeling.valueAtLocation(ll);
                }

                // A candidate split needs both a change in feature value and a change in labeling.
                if (Maths.almostEquals(lower, higher)
                        || inst.getLabeling().toString().equals(instPlusOne.getLabeling().toString())) {
                    continue;
                }
                // Counted even if the candidate is rejected below; this feeds the average gain.
                ++totalNumSplitPoints;
                double splitPoint = (lower + higher) / 2.0;
                double numPassInsts = ii + 1;
                double numFailInsts = (double)numInsts - numPassInsts;
                if (numPassInsts < (double)minNumInsts || numFailInsts < (double)minNumInsts) {
                    continue;
                }
                double passProportion = numPassInsts / (double)numInsts;
                if (Maths.almostEquals(passProportion, 0.0) || Maths.almostEquals(passProportion, 1.0)) {
                    continue;
                }

                double passEntropy = 0.0;
                double failEntropy = 0.0;
                for (int ci = 0; ci < numLabels; ci++) {
                    if (numPassInsts > 0.0) {
                        double p = passTestTargetCounts[ci] / numPassInsts;
                        if (p > 0.0) {
                            passEntropy -= p * Math.log(p) / log2;
                        }
                    }
                    if (numFailInsts > 0.0) {
                        double failTestTargetCount = targetCounts[ci] - passTestTargetCounts[ci];
                        double p = failTestTargetCount / numFailInsts;
                        if (p > 0.0) {
                            failEntropy -= p * Math.log(p) / log2;
                        }
                    }
                }
                double gainDT = baseEntropy - passProportion * passEntropy - (1.0 - passProportion) * failEntropy;
                infoGainSum += gainDT;
                double splitDT = -passProportion * Math.log(passProportion) / log2 - (1.0 - passProportion) * Math.log(1.0 - passProportion) / log2;
                double gainRatio = gainDT / splitDT;
                splitInfo.put(Double.valueOf(splitPoint), new Point2D.Double(gainDT, gainRatio));
            }
        }

        double[] gainRatios = new double[numFeatures];
        double[] splitPoints = new double[numFeatures];
        int numSplitsForBestFeature = 0;
        if (totalNumSplitPoints == 0 || Maths.almostEquals(infoGainSum, 0.0)) {
            // Nothing useful to split on: report all-zero gain ratios and split points.
            return new Object[]{gainRatios, splitPoints, Double.valueOf(baseEntropy), baseLabelDistribution, Integer.valueOf(numSplitsForBestFeature)};
        }

        // Pass 2: per feature, pick the split with the highest gain ratio among those
        // whose info gain is at least average (ties broken by higher info gain).
        double avgInfoGain = infoGainSum / (double)totalNumSplitPoints;
        double maxGainRatio = 0.0;
        double gainForMaxGainRatio = 0.0;
        int numBelowAvgGain = 0;
        for (int fi = 0; fi < numFeatures; fi++) {
            Hashtable<Double, Point2D.Double> splitInfo = featureToInfo.get(fi);
            double featureMaxGainRatio = 0.0;
            double featureGainForMaxGainRatio = 0.0;
            // Stays NaN when the feature has no split with at-least-average info gain
            // and a gain ratio (or tie-breaking info gain) above zero.
            double bestSplitPoint = Double.NaN;
            for (Double key : splitInfo.keySet()) {
                Point2D.Double pt = splitInfo.get(key);
                double splitPoint = key.doubleValue();
                double infoGain = pt.getX();
                double gainRatio = pt.getY();
                if (infoGain < avgInfoGain) {
                    ++numBelowAvgGain;
                    continue;
                }
                boolean better = gainRatio > featureMaxGainRatio
                    || (gainRatio == featureMaxGainRatio && infoGain > featureGainForMaxGainRatio);
                if (better) {
                    featureMaxGainRatio = gainRatio;
                    featureGainForMaxGainRatio = infoGain;
                    bestSplitPoint = splitPoint;
                }
            }
            // NOTE: the original code asserted "bestSplitPoint != Double.NaN" here, which is
            // always true (NaN compares unequal to everything) and so checked nothing. A real
            // Double.isNaN check would reject features that legitimately have no qualifying
            // split, so NaN is kept as the "no split" marker instead.
            gainRatios[fi] = featureMaxGainRatio;
            splitPoints[fi] = bestSplitPoint;
            if (featureMaxGainRatio > maxGainRatio
                    || (featureMaxGainRatio == maxGainRatio && featureGainForMaxGainRatio > gainForMaxGainRatio)) {
                maxGainRatio = featureMaxGainRatio;
                gainForMaxGainRatio = featureGainForMaxGainRatio;
                numSplitsForBestFeature = splitInfo.size();
            }
        }
        logger.info("label distrib:\n" + baseLabelDistribution);
        logger.info("base entropy=" + baseEntropy + ", info gain sum=" + infoGainSum + ", total num split points=" + totalNumSplitPoints + ", avg info gain=" + avgInfoGain + ", num splits with < avg gain=" + numBelowAvgGain);
        return new Object[]{gainRatios, splitPoints, Double.valueOf(baseEntropy), baseLabelDistribution, Integer.valueOf(numSplitsForBestFeature)};
    }

    /**
     * Returns the given instance indices ordered by ascending value of one feature,
     * with ties on the feature value broken by ascending instance index.
     *
     * @param ilist         the instances; data must be {@link FeatureVector}s
     * @param instIndices   indices into {@code ilist} to sort; not modified
     * @param featureIndex  index of the feature whose value is the sort key
     * @return a new array containing the same indices in sorted order
     */
    public static int[] sortInstances(InstanceList ilist, int[] instIndices, int featureIndex) {
        // Each point carries (x = instance index, y = feature value).
        ArrayList<Point2D.Double> indexedValues = new ArrayList<Point2D.Double>();
        for (int ii = 0; ii < instIndices.length; ii++) {
            Instance inst = (Instance)ilist.get(instIndices[ii]);
            FeatureVector fv = (FeatureVector)inst.getData();
            indexedValues.add(new Point2D.Double(instIndices[ii], fv.value(featureIndex)));
        }
        Collections.sort(indexedValues, new Comparator<Point2D.Double>(){

            public int compare(Point2D.Double p1, Point2D.Double p2) {
                if (p1.y == p2.y) {
                    // Instance indices are expected to be distinct, so a full tie should not occur.
                    assert (p1.x != p2.x);
                    return p1.x > p2.x ? 1 : -1;
                }
                return p1.y > p2.y ? 1 : -1;
            }
        });
        int[] sorted = new int[instIndices.length];
        for (int i = 0; i < indexedValues.size(); i++) {
            sorted[i] = (int)indexedValues.get(i).getX();
        }
        return sorted;
    }

    /**
     * Builds a {@code GainRatio} over every instance in the list, requiring at
     * least 2 instances on each side of a split.
     *
     * @param ilist the instances to analyze
     * @return the gain-ratio ranking for {@code ilist}
     */
    public static GainRatio createGainRatio(InstanceList ilist) {
        int[] instIndices = new int[ilist.size()];
        for (int ii = 0; ii < instIndices.length; ii++) {
            instIndices[ii] = ii;
        }
        return GainRatio.createGainRatio(ilist, instIndices, 2);
    }

    /**
     * Builds a {@code GainRatio} over the selected instances.
     *
     * @param ilist        the instances to analyze
     * @param instIndices  indices into {@code ilist} of the instances to use
     * @param minNumInsts  minimum number of instances required on each side of a split
     * @return the gain-ratio ranking for the selected instances
     */
    public static GainRatio createGainRatio(InstanceList ilist, int[] instIndices, int minNumInsts) {
        // Unpack the positional result of calcGainRatios (see its @return for the layout).
        Object[] objs = GainRatio.calcGainRatios(ilist, instIndices, minNumInsts);
        double[] gainRatios = (double[])objs[0];
        double[] splitPoints = (double[])objs[1];
        double baseEntropy = (Double)objs[2];
        LabelVector baseLabelDistribution = (LabelVector)objs[3];
        int numSplitPointsForBestFeature = (Integer)objs[4];
        return new GainRatio(ilist.getDataAlphabet(), gainRatios, splitPoints, baseEntropy, baseLabelDistribution, numSplitPointsForBestFeature, minNumInsts);
    }

    /**
     * Stores precomputed results; use {@link #createGainRatio(InstanceList)} or
     * {@link #createGainRatio(InstanceList, int[], int)} to compute them.
     *
     * @param dataAlphabet                  the feature alphabet
     * @param gainRatios                    gain ratio per feature index
     * @param splitPoints                   split point per feature index
     * @param baseEntropy                   entropy of the label distribution before splitting
     * @param baseLabelDistribution         label distribution before splitting
     * @param numSplitPointsForBestFeature  number of scored splits of the best feature
     * @param minNumInsts                   minimum instances required on each side of a split
     */
    protected GainRatio(Alphabet dataAlphabet, double[] gainRatios, double[] splitPoints, double baseEntropy, LabelVector baseLabelDistribution, int numSplitPointsForBestFeature, int minNumInsts) {
        super(dataAlphabet, gainRatios);
        this.m_splitPoints = splitPoints;
        this.m_baseEntropy = baseEntropy;
        this.m_baseLabelDistribution = baseLabelDistribution;
        this.m_numSplitPointsForBestFeature = numSplitPointsForBestFeature;
        this.m_minNumInsts = minNumInsts;
    }

    /**
     * @return the split point stored for the feature at rank 0
     */
    public double getMaxValuedThreshold() {
        return this.getThresholdAtRank(0);
    }

    /**
     * @param rank rank position, resolved to a feature index via {@code getIndexAtRank}
     * @return the split point stored for the feature at that rank
     */
    public double getThresholdAtRank(int rank) {
        int index = this.getIndexAtRank(rank);
        return this.m_splitPoints[index];
    }

    /**
     * @return the entropy of the label distribution before splitting
     */
    public double getBaseEntropy() {
        return this.m_baseEntropy;
    }

    /**
     * @return the label distribution before splitting
     */
    public LabelVector getBaseLabelDistribution() {
        return this.m_baseLabelDistribution;
    }

    /**
     * @return the number of scored candidate splits recorded for the best feature
     */
    public int getNumSplitPointsForBestFeature() {
        return this.m_numSplitPointsForBestFeature;
    }
}

