/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.util;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.vector.IVector;
import org.maochen.nlp.ml.vector.LabeledVector;

/**
 * Static helpers for preparing training data: balancing label frequencies, randomly splitting a
 * data set in two, and reading sequence-labelled data from a delimited text stream.
 */
public class TrainingDataUtils {
    /**
     * Builds a shuffled, down-sampled copy of {@code trainingData} in which every label occurs
     * equally often, namely as many times as the rarest label occurs in the input.
     *
     * <p>The input list itself is never modified. Which tuples survive for the over-represented
     * labels is random, because the copy is shuffled before the surplus is discarded.
     *
     * @param trainingData the labelled tuples to balance
     * @return a new mutable list holding the balanced subset; empty when the input is empty
     */
    public static List<Tuple> createBalancedTrainingData(List<Tuple> trainingData) {
        // Nothing to balance; also avoids asking for the minimum of an empty collection.
        if (trainingData.isEmpty()) {
            return new ArrayList<>();
        }

        List<Tuple> shuffled = new ArrayList<>(trainingData);
        Collections.shuffle(shuffled);

        // Count how often each label occurs.
        Map<String, Long> labelCounts = new HashMap<>();
        for (Tuple tuple : shuffled) {
            labelCounts.merge(tuple.label, 1L, Long::sum);
        }
        long minCount = Collections.min(labelCounts.values());

        // Keep the first minCount tuples of each label, in shuffled order.
        Map<String, Long> keptPerLabel = new HashMap<>();
        List<Tuple> balanced = new ArrayList<>();
        for (Tuple tuple : shuffled) {
            long kept = keptPerLabel.getOrDefault(tuple.label, 0L);
            if (kept < minCount) {
                keptPerLabel.put(tuple.label, kept + 1);
                balanced.add(tuple);
            }
        }
        return balanced;
    }

    /**
     * Randomly partitions {@code trainingData} into a smaller and a larger list.
     *
     * <p>A {@code proportion} above 0.5 is mirrored to {@code 1.0 - proportion}, so the left
     * element of the returned pair is always the smaller part. The small list holds
     * {@code floor(proportion * size)} tuples drawn uniformly at random without replacement; the
     * large list holds every remaining tuple. Both lists keep the relative order of the input.
     *
     * @param trainingData the tuples to split; not modified
     * @param proportion   the fraction of the data for one of the two parts, within [0.0, 1.0]
     * @return a pair of (small list, large list)
     * @throws IllegalArgumentException if {@code proportion} is NaN or outside [0.0, 1.0]
     */
    public static Pair<List<Tuple>, List<Tuple>> splitData(List<Tuple> trainingData, double proportion) {
        // Phrased as a negated range check so that NaN is rejected as well.
        if (!(proportion >= 0.0 && proportion <= 1.0)) {
            throw new IllegalArgumentException("Proportion should between 0.0 - 1.0");
        }
        if (proportion > 0.5) {
            proportion = 1.0 - proportion;
        }

        int total = trainingData.size();
        int smallListSize = (int) Math.floor(proportion * (double) total);

        // smallListSize is at most total / 2, so this rejection sampling always terminates.
        // nextInt(total) covers every index, including the last one.
        Set<Integer> picked = new HashSet<>();
        ThreadLocalRandom random = ThreadLocalRandom.current();
        while (picked.size() < smallListSize) {
            picked.add(random.nextInt(total));
        }

        List<Tuple> smallList = new ArrayList<>(smallListSize);
        List<Tuple> largeList = new ArrayList<>(total - smallListSize);
        int index = 0;
        for (Tuple tuple : trainingData) {
            if (picked.contains(index)) {
                smallList.add(tuple);
            } else {
                largeList.add(tuple);
            }
            index++;
        }
        return new ImmutablePair<>(smallList, largeList);
    }

    /**
     * Reads sequence-labelled data from a text stream.
     *
     * <p>Every non-blank line is trimmed and split on {@code delimiter} into fields. The field at
     * index {@code tagCol} becomes the tuple's label and all other fields, in order, become its
     * features. A blank line closes the current sequence and starts the next one; consecutive
     * blank lines therefore yield sequences without entries. A final sequence that is not
     * followed by a blank line is included as long as it has at least one entry.
     *
     * <p>Sequence ids count up from 0 in file order; tuple ids restart at 0 in every sequence.
     * The stream is decoded with the JVM's default charset and is closed before this method
     * returns. If an {@link IOException} occurs, its stack trace is printed and the sequences
     * completed so far are returned.
     *
     * @param trainingFile the stream to read from
     * @param delimiter    the field separator, interpreted as a regular expression by
     *                     {@link String#split(String)}
     * @param tagCol       zero-based index of the field holding the label
     * @return the sequences in file order
     */
    public static List<SequenceTuple> readSeqFile(InputStream trainingFile, String delimiter, int tagCol) {
        List<SequenceTuple> data = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new InputStreamReader(trainingFile))) {
            int seqId = 0;
            int tupleId = 0;
            SequenceTuple current = newSequence(seqId);
            for (String line = br.readLine(); line != null; line = br.readLine()) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    // Blank line: the current sequence is complete.
                    data.add(current);
                    tupleId = 0;
                    current = newSequence(++seqId);
                    continue;
                }

                String[] fields = trimmed.split(delimiter);
                List<String> feats = new ArrayList<>(fields.length);
                for (int i = 0; i < fields.length; i++) {
                    if (i != tagCol) {
                        feats.add(fields[i]);
                    }
                }
                LabeledVector v = new LabeledVector(feats.toArray(new String[0]));
                current.entries.add(new Tuple(tupleId++, (IVector) v, fields[tagCol]));
            }
            // The input may end without a trailing blank line; do not lose the last sequence.
            if (!current.entries.isEmpty()) {
                data.add(current);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return data;
    }

    /** Creates a sequence with the given id and an empty, mutable entry list. */
    private static SequenceTuple newSequence(int id) {
        SequenceTuple sequence = new SequenceTuple();
        sequence.entries = new ArrayList<>();
        sequence.id = id;
        return sequence;
    }
}

