/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.util;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.vector.FeatNamedVector;
import org.maochen.nlp.ml.vector.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for preparing training data: label balancing, removal of
 * uninformative features, random train/test style splitting and reading of
 * delimiter-separated sequence files.
 */
public class TrainingDataUtils {
    private static final Logger LOG = LoggerFactory.getLogger(TrainingDataUtils.class);

    /**
     * Down-samples the data so that every label keeps at most as many tuples as the
     * rarest label has. Tuples are kept in their original order; for each label the
     * first occurrences are the ones retained.
     *
     * @param trainingData the tuples to balance; not modified
     * @return a new list holding the retained tuples (empty when the input is empty)
     */
    public static List<Tuple> createBalancedTrainingData(List<Tuple> trainingData) {
        Map<String, Long> labelCounts = trainingData.stream()
                .collect(Collectors.groupingBy(t -> t.label, Collectors.counting()));
        // Size of the rarest class; every class is capped at this many tuples.
        long cap = labelCounts.values().stream().mapToLong(Long::longValue).min().orElse(0L);

        Map<String, Long> keptPerLabel = new HashMap<>();
        List<Tuple> balanced = new ArrayList<>();
        for (Tuple t : trainingData) {
            long keptSoFar = keptPerLabel.getOrDefault(t.label, 0L);
            if (keptSoFar >= cap) {
                continue;
            }
            keptPerLabel.put(t.label, keptSoFar + 1L);
            balanced.add(t);
        }
        return balanced;
    }

    /**
     * Selects the names of features that are considered uninformative.
     *
     * @param featNameValues   for each feature name, a map from observed value to the
     *                         number of times that value was seen
     * @param trainingDataSize total number of training tuples
     * @return names of the features matching {@link #isUninformative(Map, int)}
     */
    private static Set<String> getSingleValFeat(Map<String, Map<Double, Integer>> featNameValues, int trainingDataSize) {
        return featNameValues.entrySet().stream()
                .filter(e -> isUninformative(e.getValue(), trainingDataSize))
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }

    /**
     * A feature is uninformative when it has exactly one distinct value that was seen
     * either once or in every tuple, or when it has exactly two distinct values and
     * at least one of them was seen only once.
     */
    private static boolean isUninformative(Map<Double, Integer> valueCounts, int trainingDataSize) {
        List<Integer> counts = new ArrayList<>(valueCounts.values());
        if (counts.size() == 1) {
            int only = counts.get(0);
            return only == 1 || only == trainingDataSize;
        }
        if (counts.size() == 2) {
            return counts.get(0) == 1 || counts.get(1) == 1;
        }
        return false;
    }

    /**
     * Removes uninformative features (see {@link #getSingleValFeat(Map, int)}) from
     * every tuple's vector in place. For {@link FeatNamedVector}s the feature names
     * are trimmed along with the values; for other vectors the feature "name" is the
     * positional index of the value.
     *
     * @param trainingData tuples whose vectors are rewritten in place
     */
    public static void reduceDimension(List<Tuple> trainingData) {
        // Pass 1: histogram of observed values per feature name.
        Map<String, Map<Double, Integer>> featNameValues = new HashMap<>();
        for (Tuple t : trainingData) {
            double[] featValues = t.vector.getVector();
            String[] names = featureNames(t, featValues.length);
            for (int i = 0; i < featValues.length; ++i) {
                featNameValues.computeIfAbsent(names[i], k -> new HashMap<>())
                        .merge(featValues[i], 1, Integer::sum);
            }
        }

        Set<String> singleValFeats = TrainingDataUtils.getSingleValFeat(featNameValues, trainingData.size());
        if (LOG.isDebugEnabled()) {
            LOG.debug("Single value feats: ");
            LOG.debug(singleValFeats.toString().replaceAll(", ", System.lineSeparator()));
        }

        // Pass 2: rewrite each vector without the selected features.
        // Parsed lazily so that purely named data never attempts to parse feature names as ints.
        Set<Integer> positionalIndices = null;
        for (Tuple t : trainingData) {
            double[] originalValues = t.vector.getVector();
            Set<Integer> indicesToRemove;
            if (t.vector instanceof FeatNamedVector) {
                String[] names = ((FeatNamedVector) t.vector).featsName;
                indicesToRemove = new HashSet<>();
                for (int i = 0; i < names.length; ++i) {
                    if (singleValFeats.contains(names[i])) {
                        indicesToRemove.add(i);
                    }
                }
            } else {
                if (positionalIndices == null) {
                    positionalIndices = singleValFeats.stream().map(Integer::parseInt).collect(Collectors.toSet());
                }
                indicesToRemove = positionalIndices;
            }

            int[] kept = keptIndices(originalValues.length, indicesToRemove);
            double[] reducedValues = new double[kept.length];
            for (int k = 0; k < kept.length; ++k) {
                reducedValues[k] = originalValues[kept[k]];
            }
            t.vector.setVector(reducedValues);

            if (!(t.vector instanceof FeatNamedVector)) {
                continue;
            }
            // Names are read after setVector, matching the order in which values were replaced.
            FeatNamedVector namedVector = (FeatNamedVector) t.vector;
            String[] originalNames = namedVector.featsName;
            String[] reducedNames = new String[kept.length];
            for (int k = 0; k < kept.length; ++k) {
                reducedNames[k] = originalNames[kept[k]];
            }
            namedVector.featsName = reducedNames;
        }
    }

    /** Returns the tuple's feature names, or the indices "0".."length-1" for unnamed vectors. */
    private static String[] featureNames(Tuple t, int length) {
        if (t.vector instanceof FeatNamedVector) {
            return ((FeatNamedVector) t.vector).featsName;
        }
        return IntStream.range(0, length).mapToObj(String::valueOf).toArray(String[]::new);
    }

    /** Returns, in ascending order, the indices in [0, length) that are not in {@code removed}. */
    private static int[] keptIndices(int length, Set<Integer> removed) {
        return IntStream.range(0, length).filter(i -> !removed.contains(i)).toArray();
    }

    /**
     * Randomly splits the data into a smaller and a larger part.
     *
     * <p>A proportion above 0.5 is mirrored to {@code 1 - proportion}, so the left
     * element of the result is always the smaller part. The smaller part holds
     * {@code floor(proportion * size)} distinct, uniformly sampled tuples; the larger
     * part holds the remaining tuples in their original order.
     *
     * @param trainingData the tuples to split; not modified
     * @param proportion   fraction in [0.0, 1.0] assigned to one side of the split
     * @return pair of (smaller list, larger list)
     * @throws IllegalArgumentException if {@code proportion} is outside [0.0, 1.0]
     */
    public static Pair<List<Tuple>, List<Tuple>> splitData(List<Tuple> trainingData, double proportion) {
        if (proportion < 0.0 || proportion > 1.0) {
            throw new IllegalArgumentException("Proportion should between 0.0 - 1.0");
        }
        if (proportion > 0.5) {
            proportion = 1.0 - proportion;
        }

        int total = trainingData.size();
        int smallListSize = (int) Math.floor(proportion * (double) total);

        // Rejection sampling of distinct indices. Because smallListSize <= total / 2,
        // a free index is hit with probability >= 0.5 on every draw.
        Set<Integer> indices = new HashSet<>();
        while (indices.size() < smallListSize) {
            // Scale by total (not total - 1) so the last element can be selected too.
            indices.add((int) (Math.random() * (double) total));
        }

        List<Tuple> smallList = new ArrayList<>(indices.size());
        for (Integer index : indices) {
            smallList.add(trainingData.get(index));
        }
        List<Tuple> largeList = new ArrayList<>(total - indices.size());
        for (int i = 0; i < total; ++i) {
            if (!indices.contains(i)) {
                largeList.add(trainingData.get(i));
            }
        }
        return new ImmutablePair<>(smallList, largeList);
    }

    /**
     * Reads sequences from a delimiter-separated stream. Each non-blank line is one
     * tuple: the column at {@code tagCol} is the label and the remaining columns, in
     * order, are passed to {@link FeatNamedVector}. A blank line closes the current
     * sequence (so consecutive blank lines yield empty sequences). A final sequence
     * that is not followed by a blank line is included as well.
     *
     * <p>If reading fails, the error is logged and the sequences completed so far are
     * returned.
     *
     * @param trainingFile stream to read; closed by this method
     * @param delimiter    regular expression used to split each line into columns
     * @param tagCol       zero-based index of the label column
     * @return the sequences read, with ids assigned in reading order starting at 0
     */
    public static List<SequenceTuple> readSeqFile(InputStream trainingFile, String delimiter, int tagCol) {
        List<SequenceTuple> data = new ArrayList<>();
        // NOTE(review): the reader uses the platform default charset — confirm whether UTF-8 should be forced.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(trainingFile))) {
            int seqId = 0;
            int tupleId = 0;
            SequenceTuple current = newSequence(seqId);
            String line;
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    data.add(current);
                    current = newSequence(++seqId);
                    tupleId = 0;
                    continue;
                }
                String[] fields = trimmed.split(delimiter);
                String[] feats = IntStream.range(0, fields.length)
                        .filter(i -> i != tagCol)
                        .mapToObj(i -> fields[i])
                        .toArray(String[]::new);
                FeatNamedVector v = new FeatNamedVector(feats);
                current.entries.add(new Tuple(tupleId++, (IVector) v, fields[tagCol]));
            }
            // Input without a trailing blank line: do not lose the last sequence.
            if (!current.entries.isEmpty()) {
                data.add(current);
            }
        } catch (IOException e) {
            LOG.error("Failed to read sequence training data; returning sequences read so far.", e);
        }
        return data;
    }

    /** Creates an empty sequence with the given id. */
    private static SequenceTuple newSequence(int id) {
        SequenceTuple sequence = new SequenceTuple();
        sequence.entries = new ArrayList<>();
        sequence.id = id;
        return sequence;
    }
}

