/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.hmm;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.hmm.HMMModel;
import org.maochen.nlp.ml.classifier.hmm.Viterbi;
import org.maochen.nlp.ml.classifier.hmm.WordUtils;
import org.maochen.nlp.ml.vector.LabeledVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Count-based hidden Markov model part-of-speech tagger.
 *
 * <p>Training data is read one token per line (word and tag columns split by a regex delimiter),
 * with blank lines separating sentences. {@link #train(List)} counts word/tag emissions and
 * tag/tag transitions over sentences padded with {@link #START} and {@link #END}, then normalizes
 * the counts into probabilities. Decoding is delegated to {@code Viterbi.resolve}.
 */
public class HMM {
    private static final Logger LOG = LoggerFactory.getLogger(HMM.class);

    /** Feature-map key under which {@link #getSequenceTuple} stores a sentence's words. */
    public static final int WORD_INDEX = 0;

    /** Sentinel word/tag prepended to every training sentence. */
    protected static final String START = "<START>";

    /** Sentinel word/tag appended to every training sentence. */
    protected static final String END = "<END>";

    /**
     * Wraps one sentence as a {@link SequenceTuple}: the words go into the feature map under
     * {@link #WORD_INDEX} and the tags are passed as the tuple's label sequence.
     */
    private static SequenceTuple getSequenceTuple(List<String> words, List<String> pos) {
        HashMap<Integer, List<String>> wordFeat = new HashMap<>();
        wordFeat.put(WORD_INDEX, words);
        return new SequenceTuple(wordFeat, pos);
    }

    /**
     * Reads a token-per-line tagged corpus into one {@link SequenceTuple} per sentence.
     *
     * <p>Each non-blank line is split with {@link String#split(String)}, so {@code delimiter} is a
     * regular expression. Words and tags are normalized with {@code WordUtils}. A blank
     * (whitespace-only) line ends the current sentence; a final sentence that is not followed by a
     * blank line is still returned.
     *
     * @param filename     path of the corpus file
     * @param delimiter    column separator, interpreted as a regex
     * @param wordColIndex index of the column holding the word
     * @param tagColIndex  index of the column holding the tag
     * @return the sentences read; if an I/O error occurs it is logged and the sentences completed
     *     so far are returned
     * @throws ArrayIndexOutOfBoundsException if a line has fewer columns than a requested index
     */
    public static List<SequenceTuple> readTrainFile(String filename, String delimiter, int wordColIndex, int tagColIndex) {
        List<SequenceTuple> data = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(filename))) {
            List<String> words = new ArrayList<>();
            List<String> pos = new ArrayList<>();
            String line;
            while ((line = br.readLine()) != null) {
                if (!line.trim().isEmpty()) {
                    String[] tp = line.split(delimiter);
                    words.add(WordUtils.normalizeWord(tp[wordColIndex]));
                    pos.add(WordUtils.normalizeTag(tp[tagColIndex]));
                } else if (!words.isEmpty() && !pos.isEmpty()) {
                    data.add(getSequenceTuple(words, pos));
                    // Fresh lists: the finished ones are now owned by the tuple.
                    words = new ArrayList<>();
                    pos = new ArrayList<>();
                }
            }
            // Files need not end with a blank line; keep the trailing sentence in that case.
            if (!words.isEmpty() && !pos.isEmpty()) {
                data.add(getSequenceTuple(words, pos));
            }
        } catch (IOException e) {
            LOG.error("Read train file err: " + filename, e);
        }
        return data;
    }

    /**
     * Rescales the values of {@code counts} in place so that they sum to 1.
     *
     * <p>Only values are rewritten (via {@code put} on existing keys); no keys are added or removed.
     */
    private static void normalizeToProbabilities(Map<String, Double> counts) {
        double total = counts.values().stream().mapToDouble(Double::doubleValue).sum();
        for (String key : counts.keySet()) {
            counts.put(key, counts.get(key) / total);
        }
    }

    /**
     * Turns emission counts into probabilities by normalizing each tag column of
     * {@code model.emission}, then records each tag's smallest emission probability in
     * {@code model.emissionMin}.
     */
    public static void normalizeEmission(HMMModel model) {
        // Sequential on purpose: column maps are views over one shared table whose implementation
        // is not visible here; assume it is not thread-safe (TODO confirm in HMMModel).
        model.emission.columnMap().values().forEach(HMM::normalizeToProbabilities);
        for (String tag : model.emission.columnKeySet()) {
            double minProb = model.emission.column(tag).values().stream().min(Double::compareTo).orElse(0.0);
            model.emissionMin.put(tag, minProb);
        }
    }

    /** Turns transition counts into probabilities by normalizing each previous-tag row. */
    public static void normalizeTrans(HMMModel model) {
        // Sequential for the same reason as normalizeEmission: the table is mutated in place.
        model.transition.rowMap().values().forEach(HMM::normalizeToProbabilities);
    }

    /**
     * Trains an HMM by counting over {@code data} and normalizing the counts.
     *
     * <p>Each sentence is padded with {@link #START} and {@link #END}, used both as words and as
     * tags. Emission counts are stored as {@code emission[word][tag]} and normalized per tag
     * column. Transition counts are stored as {@code transition[prevTag][nextTag]} and normalized
     * per previous-tag row. The padding is applied to copies, so the tuples in {@code data} are
     * left untouched.
     */
    public static HMMModel train(List<SequenceTuple> data) {
        HMMModel model = new HMMModel();
        for (SequenceTuple seqTuple : data) {
            // toCollection guarantees a mutable list; Collectors.toList() does not.
            List<String> words = seqTuple.entries.stream()
                    .map(entry -> ((LabeledVector) entry.vector).featsName[0])
                    .collect(Collectors.toCollection(ArrayList::new));
            // Copy: getLabel() may hand back the tuple's own list (not visible here), and padding
            // it in place would corrupt the caller's data.
            List<String> tags = new ArrayList<>(seqTuple.getLabel());
            words.add(0, START);
            words.add(END);
            tags.add(0, START);
            tags.add(END);

            for (int i = 0; i < words.size(); i++) {
                Double count = model.emission.get(words.get(i), tags.get(i));
                model.emission.put(words.get(i), tags.get(i), count == null ? 1.0 : count + 1.0);
            }
            // Bound on the padded tag list so START -> first tag and last tag -> END are counted.
            for (int i = 0; i + 1 < tags.size(); i++) {
                Double count = model.transition.get(tags.get(i), tags.get(i + 1));
                model.transition.put(tags.get(i), tags.get(i + 1), count == null ? 1.0 : count + 1.0);
            }
        }
        normalizeEmission(model);
        normalizeTrans(model);
        return model;
    }

    /** Returns the most likely tag sequence for {@code words}, as computed by {@code Viterbi}. */
    public static List<String> viterbi(HMMModel model, String[] words) {
        return Viterbi.resolve(model, words);
    }

    /**
     * Tags every sentence in {@code testFile} and prints per-token mismatches plus overall accuracy
     * to standard output.
     *
     * <p>A predicted tag counts as correct when either normalized tag is a prefix of the other.
     * NOTE(review): presumably this is so a coarse tag matches its refinement (e.g. NN vs NNS);
     * confirm that this is intended.
     *
     * @param testFile corpus in the same format accepted by
     *     {@link #readTrainFile(String, String, int, int)}
     */
    public static void eval(HMMModel model, String testFile, String delimiter, int wordColIndex, int tagColIndex) {
        List<SequenceTuple> testData = readTrainFile(testFile, delimiter, wordColIndex, tagColIndex);
        int totalCount = 0;
        int errCount = 0;
        for (SequenceTuple sequenceTuple : testData) {
            String[] words = sequenceTuple.entries.stream()
                    .map(entry -> ((LabeledVector) entry.vector).featsName[0])
                    .toArray(String[]::new);
            List<String> result = viterbi(model, words);
            for (int i = 0; i < result.size(); i++) {
                totalCount++;
                String expected = WordUtils.normalizeTag(((Tuple) sequenceTuple.entries.get(i)).label);
                String actual = WordUtils.normalizeTag(result.get(i));
                if (actual.startsWith(expected) || expected.startsWith(actual)) {
                    continue;
                }
                System.out.println(words[i] + " exp: " + expected + " actual: " + result.get(i));
                errCount++;
            }
        }
        if (totalCount == 0) {
            // Avoid reporting NaN% when the file held no sentences.
            System.out.println("accuracy: no tokens evaluated in " + testFile);
            return;
        }
        int correctCount = totalCount - errCount;
        double accuracy = 100.0 * correctCount / totalCount;
        System.out.println("accuracy: " + correctCount + "/" + totalCount + " -> " + String.format("%.2f", accuracy) + "%");
    }

    /**
     * Deserializes a model previously written by {@link #saveModel(String, HMMModel)}.
     *
     * <p>This uses Java native serialization, which is unsafe on untrusted input: only load model
     * files from a trusted source.
     *
     * @return the model, or {@code null} if it could not be read (the error is logged)
     */
    public static HMMModel loadModel(String modelPath) {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelPath))) {
            return (HMMModel) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            LOG.error("Load model err.", e);
            return null;
        }
    }

    /** Serializes {@code model} to {@code modelPath}; failures are logged, not thrown. */
    public static void saveModel(String modelPath, HMMModel model) {
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelPath))) {
            oos.writeObject(model);
        } catch (IOException e) {
            LOG.error("Persist model err.", e);
        }
    }

    /**
     * Demo: trains on {@code development.pos} + {@code training.pos}, evaluates, and tags one
     * sample sentence.
     *
     * @param args optional {@code args[0]}: directory containing the corpus files; defaults to the
     *     original author's local path
     */
    public static void main(String[] args) throws InterruptedException {
        String prefix = args.length > 0
                ? args[0]
                : "/Users/mguan/Dropbox/Course/Natural Lang Processing/HW/HW4_POSTagger_HMM/Homework4_corpus/POSData";
        List<SequenceTuple> data = readTrainFile(prefix + "/development.pos", "\t", 0, 1);
        data.addAll(readTrainFile(prefix + "/training.pos", "\t", 0, 1));
        HMMModel model = train(data);
        // NOTE: training.pos is also part of the training data, so this reports training-set fit,
        // not generalization.
        eval(model, prefix + "/training.pos", "\t", 0, 1);
        String str = "The quick brown fox jumped over the lazy dog .";
        List<String> result = viterbi(model, str.split("\\s"));
        System.out.println(str);
        System.out.println(result);
    }
}

