/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

/**
 * Static helpers for building feature-label constraints.
 *
 * <p>They cover reading constraint files, selecting candidate features, assigning
 * labels to features, and estimating per-feature target distributions over labels.
 *
 * <p>Constraint maps are keyed by feature index (into the data alphabet). Their values
 * are arrays indexed by label index (into the target alphabet).
 */
public class FeatureConstraintUtil {
    private static final Logger logger = MalletLogger.getLogger(FeatureConstraintUtil.class.getName());

    /** Additive smoothing applied to count vectors before they are normalized to sum to 1. */
    private static final double SMOOTHING = 1.0E-8;

    /**
     * Reads range constraints from a whitespace-separated text file.
     *
     * <p>Each non-blank line has the form {@code featureName label:value label:lower,upper ...}.
     * A single value sets both bounds for that label; a comma-separated pair sets the lower
     * and upper bound separately. Labels not mentioned on a line keep the bounds
     * {@code [-Infinity, -Infinity]}.
     *
     * <p>On any I/O or parse error, or an unknown feature/label name, the stack trace is
     * printed and the JVM exits with status 1.
     *
     * @param filename path of the constraints file
     * @param data instance list whose data/target alphabets resolve feature and label names
     * @return map from feature index to a {@code [numLabels][2]} array of {lower, upper} bounds
     */
    public static HashMap<Integer, double[][]> readRangeConstraintsFromFile(String filename, InstanceList data) {
        HashMap<Integer, double[][]> constraints = new HashMap<Integer, double[][]>();
        Alphabet labelAlphabet = data.getTargetAlphabet();
        int numLabels = labelAlphabet.size();
        for (int li = 0; li < numLabels; li++) {
            logger.fine("Target label " + li + ": " + labelAlphabet.lookupObject(li));
        }
        try {
            for (String line : readNonBlankLines(filename)) {
                String[] split = line.split("\\s+");
                int featureIndex = lookupFeature(data, split[0]);
                double[][] bounds = new double[numLabels][2];
                for (double[] row : bounds) {
                    Arrays.fill(row, Double.NEGATIVE_INFINITY);
                }
                for (int index = 1; index < split.length; index++) {
                    String[] labelSplit = split[index].split(":");
                    int li = lookupLabel(data, labelSplit[0]);
                    String value = labelSplit[1];
                    if (value.contains(",")) {
                        String[] rangeSplit = value.split(",");
                        bounds[li][0] = Double.parseDouble(rangeSplit[0]);
                        bounds[li][1] = Double.parseDouble(rangeSplit[1]);
                    } else {
                        double prob = Double.parseDouble(value);
                        bounds[li][0] = prob;
                        bounds[li][1] = prob;
                    }
                }
                constraints.put(featureIndex, bounds);
            }
        } catch (Exception e) {
            reportAndExit(e);
        }
        return constraints;
    }

    /**
     * Reads point constraints, auto-detecting the file format.
     *
     * <p>If the first non-blank line contains no {@code ':'}, the file is parsed by
     * {@link #readConstraintsFromFileIndex}; otherwise by {@link #readConstraintsFromFileString}.
     */
    public static HashMap<Integer, double[]> readConstraintsFromFile(String filename, InstanceList data) {
        if (FeatureConstraintUtil.testConstraintsFileIndexBased(filename)) {
            return FeatureConstraintUtil.readConstraintsFromFileIndex(filename, data);
        }
        return FeatureConstraintUtil.readConstraintsFromFileString(filename, data);
    }

    /**
     * Reads point constraints whose lines look like {@code featureName label:prob label:prob ...}.
     *
     * <p>On any I/O or parse error, or an unknown feature/label name, the stack trace is
     * printed and the JVM exits with status 1.
     *
     * @return map from feature index to an array of length numLabels. Labels not listed
     *     on a line are 0.
     */
    public static HashMap<Integer, double[]> readConstraintsFromFileString(String filename, InstanceList data) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        int numLabels = data.getTargetAlphabet().size();
        try {
            for (String line : readNonBlankLines(filename)) {
                String[] split = line.split("\\s+");
                // Fail on unknown features instead of silently storing them under key -1.
                int featureIndex = lookupFeature(data, split[0]);
                assert (split.length - 1 == numLabels) : String.valueOf(split.length) + " " + numLabels;
                // Sized by label count because entries are written at their label index.
                double[] probs = new double[numLabels];
                for (int index = 1; index < split.length; index++) {
                    String[] labelSplit = split[index].split(":");
                    probs[lookupLabel(data, labelSplit[0])] = Double.parseDouble(labelSplit[1]);
                }
                constraints.put(featureIndex, probs);
            }
        } catch (Exception e) {
            reportAndExit(e);
        }
        return constraints;
    }

    /**
     * Reads point constraints whose lines look like {@code featureIndex p0 p1 ... p(numLabels-1)}.
     *
     * <p>The probabilities are positional: the k-th value belongs to label index k. On any
     * I/O or parse error, the stack trace is printed and the JVM exits with status 1.
     */
    public static HashMap<Integer, double[]> readConstraintsFromFileIndex(String filename, InstanceList data) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        try {
            for (String line : readNonBlankLines(filename)) {
                String[] split = line.split("\\s+");
                int featureIndex = Integer.parseInt(split[0]);
                assert (split.length - 1 == data.getTargetAlphabet().size());
                double[] probs = new double[split.length - 1];
                for (int index = 1; index < split.length; index++) {
                    probs[index - 1] = Double.parseDouble(split[index]);
                }
                constraints.put(featureIndex, probs);
            }
        } catch (Exception e) {
            reportAndExit(e);
        }
        return constraints;
    }

    /**
     * Returns true when the constraints file appears to be index-based.
     *
     * <p>A file counts as index-based when its first non-blank line has no {@code ':'}, or
     * when it has no non-blank lines at all. Either reader yields an empty map for such a file.
     */
    private static boolean testConstraintsFileIndexBased(String filename) {
        ArrayList<String> lines = null;
        try {
            lines = readNonBlankLines(filename);
        } catch (IOException e) {
            reportAndExit(e);
        }
        return lines == null || lines.isEmpty() || !lines.get(0).contains(":");
    }

    /**
     * Returns the trimmed, non-blank lines of a file.
     *
     * <p>The reader is closed even when reading fails.
     */
    private static ArrayList<String> readNonBlankLines(String filename) throws IOException {
        ArrayList<String> lines = new ArrayList<String>();
        BufferedReader reader = new BufferedReader(new FileReader(new File(filename)));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.length() > 0) {
                    lines.add(line);
                }
            }
        } finally {
            reader.close();
        }
        return lines;
    }

    /** Returns the data-alphabet index of {@code featureName}, throwing if it is not present. */
    private static int lookupFeature(InstanceList data, String featureName) {
        int featureIndex = data.getDataAlphabet().lookupIndex(featureName, false);
        if (featureIndex == -1) {
            throw new RuntimeException("Feature " + featureName + " not found in the alphabet!");
        }
        return featureIndex;
    }

    /** Returns the target-alphabet index of {@code labelName}, throwing if it is not present. */
    private static int lookupLabel(InstanceList data, String labelName) {
        int li = data.getTargetAlphabet().lookupIndex(labelName, false);
        if (li == -1) {
            throw new IllegalArgumentException("Label " + labelName + " not found");
        }
        return li;
    }

    /** Failure policy shared by the file readers: print the stack trace and exit with status 1. */
    private static void reportAndExit(Exception e) {
        e.printStackTrace();
        System.exit(1);
    }

    /**
     * Returns the indices of the {@code numFeatures} highest-ranked features by info gain.
     *
     * <p>The count is clamped to the number of ranked features, so asking for more than
     * exist returns all of them instead of throwing.
     */
    public static ArrayList<Integer> selectFeaturesByInfoGain(InstanceList list, int numFeatures) {
        ArrayList<Integer> features = new ArrayList<Integer>();
        InfoGain infogain = new InfoGain(list);
        int limit = Math.min(numFeatures, infogain.numLocations());
        for (int rank = 0; rank < limit; rank++) {
            features.add(infogain.getIndexAtRank(rank));
        }
        return features;
    }

    /**
     * Selects up to {@code numSelFeatures} distinct features by walking the LDA topics'
     * top-word lists round-robin.
     *
     * <p>The walk visits rank 0 of every topic, then rank 1, and so on. Words not in
     * {@code alphabet} are skipped. Each selected word is logged.
     *
     * @return the selected feature indices; empty when {@code numSelFeatures <= 0}
     */
    public static ArrayList<Integer> selectTopLDAFeatures(int numSelFeatures, ParallelTopicModel lda, Alphabet alphabet) {
        ArrayList<Integer> features = new ArrayList<Integer>();
        if (numSelFeatures <= 0) {
            return features;
        }
        Alphabet seqAlphabet = lda.getAlphabet();
        int numTopics = lda.getNumTopics();
        int maxRank = seqAlphabet.size();
        Object[][] sorted = lda.getTopWords(maxRank);
        for (int pos = 0; pos < maxRank; pos++) {
            for (int ti = 0; ti < numTopics; ti++) {
                // Guard against a topic row shorter than requested: the row length comes
                // from getTopWords, not from maxRank.
                if (pos >= sorted[ti].length) {
                    continue;
                }
                String feat = sorted[ti][pos].toString();
                int fi = alphabet.lookupIndex(feat, false);
                if (fi < 0 || features.contains(fi)) {
                    continue;
                }
                logger.info("Selected feature: " + feat);
                features.add(fi);
                if (features.size() == numSelFeatures) {
                    return features;
                }
            }
        }
        return features;
    }

    /**
     * Same as {@link #setTargetsUsingData(InstanceList, ArrayList, boolean)} with normalization.
     */
    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features) {
        return FeatureConstraintUtil.setTargetsUsingData(list, features, true);
    }

    /**
     * Same as {@link #setTargetsUsingData(InstanceList, ArrayList, boolean, boolean)} without
     * weighting by feature values.
     */
    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean normalize) {
        return FeatureConstraintUtil.setTargetsUsingData(list, features, false, normalize);
    }

    /**
     * Builds target label vectors for {@code features} from the labeled data.
     *
     * @param useValues weight each occurrence by the feature's value instead of 1
     * @param normalize smooth each vector by {@code 1e-8} and rescale it to sum to 1
     * @return map from feature index to its (optionally normalized) label-count vector
     */
    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean useValues, boolean normalize) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        double[][] featureLabelCounts = FeatureConstraintUtil.getFeatureLabelCounts(list, useValues);
        // featureLabelCounts has one row per alphabet entry, so the index equal to the alphabet
        // size has no row; presumably it denotes a default/bias feature (TODO confirm).
        int excludedIndex = list.getDataAlphabet().size();
        for (int i = 0; i < features.size(); i++) {
            int fi = features.get(i);
            if (fi == excludedIndex) {
                continue;
            }
            double[] prob = featureLabelCounts[fi];
            if (normalize) {
                smoothAndNormalize(prob);
            }
            constraints.put(fi, prob);
        }
        return constraints;
    }

    /**
     * Builds a heuristic target distribution for each labeled feature.
     *
     * <p>See {@code getHeuristicPrior} for how {@code majorityProb} is split.
     */
    public static HashMap<Integer, double[]> setTargetsUsingHeuristic(HashMap<Integer, ArrayList<Integer>> labeledFeatures, int numLabels, double majorityProb) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        for (int fi : labeledFeatures.keySet()) {
            ArrayList<Integer> labels = labeledFeatures.get(fi);
            constraints.put(fi, FeatureConstraintUtil.getHeuristicPrior(labels, numLabels, majorityProb));
        }
        return constraints;
    }

    /**
     * Estimates each labeled feature's target distribution by accumulating label mass over
     * the training instances that contain it.
     *
     * <p>For a labeled instance, the label mass is 1 on its best label. An unlabeled instance
     * gets a label distribution from {@code labelByVoting} over the labeled features it
     * contains. Each instance contributes its label mass times the feature's value.
     * Results are smoothed by {@code 1e-8} and normalized.
     */
    public static HashMap<Integer, double[]> setTargetsUsingFeatureVoting(HashMap<Integer, ArrayList<Integer>> labeledFeatures, InstanceList trainingData) {
        int numLabels = trainingData.getTargetAlphabet().size();
        // Fix one key order so row i of featureCounts always belongs to keys.get(i).
        ArrayList<Integer> keys = new ArrayList<Integer>(labeledFeatures.keySet());
        double[][] featureCounts = new double[keys.size()][numLabels];
        for (int ii = 0; ii < trainingData.size(); ii++) {
            Instance instance = (Instance) trainingData.get(ii);
            FeatureVector fv = (FeatureVector) instance.getData();
            Labeling labeling = instance.getLabeling();
            double[] labelDist = new double[numLabels];
            if (labeling == null) {
                FeatureConstraintUtil.labelByVoting(labeledFeatures, instance, labelDist);
            } else {
                labelDist[labeling.getBestIndex()] = 1.0;
            }
            for (int i = 0; i < keys.size(); i++) {
                int loc = fv.location(keys.get(i));
                if (loc < 0) {
                    continue;
                }
                double value = fv.valueAtLocation(loc);
                double[] counts = featureCounts[i];
                for (int li = 0; li < numLabels; li++) {
                    counts[li] += labelDist[li] * value;
                }
            }
        }
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        for (int i = 0; i < keys.size(); i++) {
            smoothAndNormalize(featureCounts[i]);
            constraints.put(keys.get(i), featureCounts[i]);
        }
        return constraints;
    }

    /**
     * Oracle-labels the given features using the labeled data.
     *
     * <p>With {@code reject}, a feature is dropped when its info gain is below the mean of
     * the top {@code 100 * numLabels} ranked features. It is also dropped when, with more than
     * two labels, more than half of the labels pass the labeling threshold. Rejections are
     * logged.
     *
     * <p>Label choice depends on the number of labels:
     * <ul>
     *   <li>With at most two labels, a feature gets its single most likely label.</li>
     *   <li>Otherwise it gets every label whose normalized count exceeds half the maximum.</li>
     * </ul>
     *
     * @return map from feature index to its chosen label indices
     */
    public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features, boolean reject) {
        HashMap<Integer, ArrayList<Integer>> labeledFeatures = new HashMap<Integer, ArrayList<Integer>>();
        double[][] featureLabelCounts = FeatureConstraintUtil.getFeatureLabelCounts(list, true);
        int numLabels = list.getTargetAlphabet().size();
        InfoGain infogain = new InfoGain(list);
        // Clamp to the number of ranked features so small alphabets don't index past the end.
        int minRank = Math.min(100 * numLabels, infogain.numLocations());
        double sum = 0.0;
        for (int rank = 0; rank < minRank; rank++) {
            sum += infogain.getValueAtRank(rank);
        }
        // With nothing ranked, no feature is rejected on info gain.
        double mean = minRank > 0 ? sum / (double) minRank : Double.NEGATIVE_INFINITY;
        for (int i = 0; i < features.size(); i++) {
            int fi = features.get(i);
            if (reject && infogain.value(fi) < mean) {
                logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
                continue;
            }
            double[] prob = featureLabelCounts[fi];
            smoothAndNormalize(prob);
            int[] sortedIndices = FeatureConstraintUtil.getMaxIndices(prob);
            ArrayList<Integer> labels = new ArrayList<Integer>();
            if (numLabels <= 2) {
                labels.add(sortedIndices[0]);
            } else {
                double threshold = prob[sortedIndices[0]] / 2.0;
                boolean discard = false;
                for (int li = 0; li < numLabels; li++) {
                    if (prob[li] > threshold) {
                        labels.add(li);
                    }
                    if (reject && labels.size() > numLabels / 2) {
                        logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
                        discard = true;
                        break;
                    }
                }
                if (discard) {
                    continue;
                }
            }
            labeledFeatures.put(fi, labels);
        }
        return labeledFeatures;
    }

    /** Same as {@link #labelFeatures(InstanceList, ArrayList, boolean)} with rejection enabled. */
    public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features) {
        return FeatureConstraintUtil.labelFeatures(list, features, true);
    }

    /**
     * Returns a {@code [numFeatures][numLabels]} matrix of feature-label counts.
     *
     * <p>Cell (fi, li) sums, over the labeled instances, {@code labeling.value(li)} times the
     * feature's value (or times 1 when {@code useValues} is false). Instances with no labeling
     * are skipped; previously they caused a NullPointerException.
     */
    public static double[][] getFeatureLabelCounts(InstanceList list, boolean useValues) {
        int numFeatures = list.getDataAlphabet().size();
        int numLabels = list.getTargetAlphabet().size();
        double[][] featureLabelCounts = new double[numFeatures][numLabels];
        double[] py = new double[numLabels];
        for (int ii = 0; ii < list.size(); ii++) {
            Instance instance = (Instance) list.get(ii);
            Labeling labeling = instance.getLabeling();
            if (labeling == null) {
                continue;
            }
            for (int li = 0; li < numLabels; li++) {
                py[li] = labeling.value(li);
            }
            FeatureVector featureVector = (FeatureVector) instance.getData();
            for (int loc = 0; loc < featureVector.numLocations(); loc++) {
                int fi = featureVector.indexAtLocation(loc);
                double val = useValues ? featureVector.valueAtLocation(loc) : 1.0;
                double[] counts = featureLabelCounts[fi];
                for (int li = 0; li < numLabels; li++) {
                    counts[li] += py[li] * val;
                }
            }
        }
        return featureLabelCounts;
    }

    /** Adds {@link #SMOOTHING} to every entry of {@code dist}, then rescales it in place to sum to 1. */
    private static void smoothAndNormalize(double[] dist) {
        MatrixOps.plusEquals(dist, SMOOTHING);
        MatrixOps.timesEquals(dist, 1.0 / MatrixOps.sum(dist));
    }

    /**
     * Builds a heuristic label distribution from a feature's labels.
     *
     * <p>{@code majorityProb} is split evenly over the distinct labels in {@code labels}, and
     * {@code 1 - majorityProb} is split evenly over the remaining labels. The result is
     * uniform when {@code labels} is empty or covers every label. Duplicate entries count
     * once, so the remaining-label share never has a zero or negative denominator.
     */
    private static double[] getHeuristicPrior(ArrayList<Integer> labels, int numLabels, double majorityProb) {
        double[] dist = new double[numLabels];
        boolean[] isMajority = new boolean[numLabels];
        int numMajority = 0;
        for (int li : labels) {
            if (!isMajority[li]) {
                isMajority[li] = true;
                numMajority++;
            }
        }
        if (numMajority == 0 || numMajority == numLabels) {
            Arrays.fill(dist, 1.0 / (double) numLabels);
            return dist;
        }
        double keywordProb = majorityProb / (double) numMajority;
        double otherProb = (1.0 - majorityProb) / (double) (numLabels - numMajority);
        for (int li = 0; li < numLabels; li++) {
            // Decide by membership, not by dist[li] == 0, so majorityProb == 0 works too.
            dist[li] = isMajority[li] ? keywordProb : otherProb;
        }
        assert (Maths.almostEquals(MatrixOps.sum(dist), 1.0));
        return dist;
    }

    /**
     * Fills {@code scores} with a label distribution by voting.
     *
     * <p>Every labeled feature present in the instance gives one vote to each of its labels.
     * With no votes, all labels get equal mass. {@code scores} is expected to be zero-filled
     * on entry; it is normalized in place to sum to 1.
     */
    private static void labelByVoting(HashMap<Integer, ArrayList<Integer>> labeledFeatures, Instance instance, double[] scores) {
        FeatureVector fv = (FeatureVector) instance.getData();
        int numFeatures = instance.getDataAlphabet().size() + 1;
        for (int fi : labeledFeatures.keySet()) {
            assert (fi < numFeatures);
            if (fv.location(fi) < 0) {
                continue;
            }
            for (int li : labeledFeatures.get(fi)) {
                scores[li] += 1.0;
            }
        }
        double sum = MatrixOps.sum(scores);
        if (sum == 0.0) {
            MatrixOps.plusEquals(scores, 1.0);
            sum = MatrixOps.sum(scores);
        }
        for (int li = 0; li < scores.length; li++) {
            scores[li] = scores[li] / sum;
        }
    }

    /**
     * Returns the indices of {@code x} ordered by descending value.
     *
     * <p>Equal values are ordered by descending index: a stable ascending sort followed by a
     * reversal puts the later index first.
     */
    private static int[] getMaxIndices(double[] x) {
        ArrayList<Element> list = new ArrayList<Element>(x.length);
        for (int i = 0; i < x.length; i++) {
            list.add(new Element(i, x[i]));
        }
        Collections.sort(list);
        Collections.reverse(list);
        int[] sortedIndices = new int[x.length];
        for (int i = 0; i < x.length; i++) {
            sortedIndices[i] = list.get(i).index;
        }
        return sortedIndices;
    }

    /** An (index, value) pair ordered by value via {@link Double#compare}. */
    private static class Element
    implements Comparable<Element> {
        private final int index;
        private final double value;

        public Element(int index, double value) {
            this.index = index;
            this.value = value;
        }

        @Override
        public int compareTo(Element element) {
            return Double.compare(this.value, element.value);
        }
    }
}

