/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.hmm;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.hmm.HMMModel;
import org.maochen.nlp.ml.classifier.hmm.Viterbi;
import org.maochen.nlp.ml.classifier.hmm.WordUtils;
import org.maochen.nlp.ml.vector.LabeledVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for training, evaluating and persisting a first-order Hidden Markov Model tagger.
 *
 * <p>During training, emission counts are stored in {@code model.emission} with the word as row
 * key and the tag as column key. Transition counts are stored in {@code model.transition} with
 * the previous tag as row key and the next tag as column key. Both tables are then normalized
 * into conditional probabilities.
 */
public class HMM {
    private static final Logger LOG = LoggerFactory.getLogger(HMM.class);

    /** Index of the word feature inside the feature map of each {@link SequenceTuple}. */
    public static final int WORD_INDEX = 0;

    /** Pseudo word/tag prepended to every training sequence. */
    protected static final String START = "<START>";

    /** Pseudo word/tag appended to every training sequence. */
    protected static final String END = "<END>";

    /**
     * Rescales the values of a map in place so they sum to 1.
     *
     * <p>Only existing keys are overwritten (no structural change), so updating while iterating
     * over the key set is safe. An all-zero map would produce NaN values.
     */
    private static final Consumer<Map<String, Double>> NORMALIZE = map -> {
        double total = map.values().stream().mapToDouble(Double::doubleValue).sum();
        for (String key : map.keySet()) {
            map.put(key, map.get(key) / total);
        }
    };

    /**
     * Wraps one sentence into a {@link SequenceTuple}.
     *
     * @param words normalized words of the sentence, stored under {@link #WORD_INDEX}
     * @param pos   normalized tags, one per word
     * @return the sequence tuple built from {@code words} and {@code pos}
     */
    private static SequenceTuple getSequenceTuple(List<String> words, List<String> pos) {
        HashMap<Integer, List<String>> wordFeat = new HashMap<>();
        wordFeat.put(WORD_INDEX, words);
        return new SequenceTuple(wordFeat, pos);
    }

    /**
     * Reads a column-formatted corpus with one token per line and sentences separated by blank
     * lines.
     *
     * <p>Words are passed through {@code WordUtils.normalizeWord} and tags through
     * {@code WordUtils.normalizeTag}. The last sentence is kept even when the file does not end
     * with a blank line. On an I/O error, the error is logged and the sentences read so far are
     * returned.
     *
     * @param filename     path of the corpus file (read with the platform default charset)
     * @param delimiter    column separator; interpreted as a regular expression by
     *                     {@link String#split(String)}
     * @param wordColIndex column index holding the word
     * @param tagColIndex  column index holding the tag
     * @return one {@link SequenceTuple} per non-empty sentence, in file order
     */
    public static List<SequenceTuple> readTrainFile(String filename, String delimiter, int wordColIndex, int tagColIndex) {
        List<SequenceTuple> data = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(filename))) {
            List<String> words = new ArrayList<>();
            List<String> pos = new ArrayList<>();
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    // A blank line closes the current sentence (consecutive blanks are ignored).
                    if (!words.isEmpty() && !pos.isEmpty()) {
                        data.add(getSequenceTuple(words, pos));
                        words = new ArrayList<>();
                        pos = new ArrayList<>();
                    }
                } else {
                    String[] tp = line.split(delimiter);
                    words.add(WordUtils.normalizeWord(tp[wordColIndex]));
                    pos.add(WordUtils.normalizeTag(tp[tagColIndex]));
                }
            }
            // Corpora frequently omit the trailing blank line; don't drop the final sentence.
            if (!words.isEmpty() && !pos.isEmpty()) {
                data.add(getSequenceTuple(words, pos));
            }
        } catch (IOException e) {
            LOG.error("Read train file err: " + filename, e);
        }
        return data;
    }

    /**
     * Normalizes each emission column (one per tag), turning counts into P(word | tag). It also
     * records the smallest emission probability of each tag in {@code model.emissionMin}.
     *
     * @param model model whose emission counts are normalized in place
     */
    static void normalizeEmission(HMMModel model) {
        model.emission.columnMap().values().forEach(NORMALIZE);
        for (String tag : model.emission.columnKeySet()) {
            double minProb = model.emission.column(tag).values().stream().min(Double::compareTo).orElse(0.0);
            model.emissionMin.put(tag, minProb);
        }
    }

    /**
     * Normalizes each transition row (one per previous tag), turning counts into
     * P(next tag | previous tag).
     *
     * @param model model whose transition counts are normalized in place
     */
    static void normalizeTrans(HMMModel model) {
        model.transition.rowMap().values().forEach(NORMALIZE);
    }

    /**
     * Extracts the word and tag sequences of a sentence, each wrapped with {@link #START} and
     * {@link #END}.
     *
     * <p>The tag list is copied before it is modified, so the caller's {@link SequenceTuple} is
     * not changed.
     *
     * @param seqTuple sentence to extract from
     * @return pair of (words, tags), both with START/END markers added
     */
    private static Pair<List<String>, List<String>> getXSeqOSeq(SequenceTuple seqTuple) {
        List<String> words = seqTuple.entries.stream()
                .map(entry -> ((LabeledVector) entry.vector).featsName[WORD_INDEX])
                .collect(Collectors.toList());
        // Defensive copy: getLabel() may expose internal state, which must not gain START/END.
        List<String> tags = new ArrayList<>(seqTuple.getLabel());
        words.add(0, START);
        words.add(END);
        tags.add(0, START);
        tags.add(END);
        return ImmutablePair.of(words, tags);
    }

    /**
     * Trains an HMM by counting emissions and transitions over the given sentences, then
     * normalizing the counts into probabilities.
     *
     * @param data training sentences
     * @return the trained model
     */
    public static HMMModel train(List<SequenceTuple> data) {
        HMMModel model = new HMMModel();
        for (SequenceTuple seqTuple : data) {
            Pair<List<String>, List<String>> wordTagPair = getXSeqOSeq(seqTuple);
            List<String> words = wordTagPair.getLeft();
            List<String> tags = wordTagPair.getRight();

            // Emission counts: (word, tag) co-occurrences, including the START/END markers.
            for (int i = 0; i < words.size(); i++) {
                Double ct = (Double) model.emission.get(words.get(i), tags.get(i));
                model.emission.put(words.get(i), tags.get(i), ct == null ? 1.0 : ct + 1.0);
            }
            // Transition counts: (previous tag, next tag) bigrams.
            for (int i = 0; i < tags.size() - 1; i++) {
                Double ct = (Double) model.transition.get(tags.get(i), tags.get(i + 1));
                model.transition.put(tags.get(i), tags.get(i + 1), ct == null ? 1.0 : ct + 1.0);
            }
        }
        normalizeEmission(model);
        normalizeTrans(model);
        return model;
    }

    /**
     * Decodes a tag sequence for {@code words} by delegating to {@link Viterbi#resolve}.
     *
     * @param model trained model
     * @param words sentence to tag
     * @return tags produced by the decoder
     */
    public static List<String> viterbi(HMMModel model, String[] words) {
        return Viterbi.resolve(model, words);
    }

    /**
     * Evaluates token-level tagging accuracy on a test corpus in the same format as
     * {@link #readTrainFile}.
     *
     * <p>A predicted tag is counted as correct when, after normalization, either tag is a prefix
     * of the other.
     *
     * @param model        trained model
     * @param testFile     path of the test corpus
     * @param delimiter    column separator (regular expression)
     * @param wordColIndex column index holding the word
     * @param tagColIndex  column index holding the gold tag
     * @param print        whether to log each mismatch and the final summary
     * @return map with the single key {@code "accuracy"}; its value is NaN when the test file
     *         contains no tokens
     */
    public static Map<String, Double> eval(HMMModel model, String testFile, String delimiter, int wordColIndex, int tagColIndex, boolean print) {
        List<SequenceTuple> testData = readTrainFile(testFile, delimiter, wordColIndex, tagColIndex);
        int totalCount = 0;
        int errCount = 0;
        for (SequenceTuple sequenceTuple : testData) {
            String[] words = sequenceTuple.entries.stream()
                    .map(entry -> ((LabeledVector) entry.vector).featsName[WORD_INDEX])
                    .toArray(String[]::new);
            List<String> predicted = viterbi(model, words);
            for (int i = 0; i < predicted.size(); i++) {
                totalCount++;
                String expected = WordUtils.normalizeTag(((Tuple) sequenceTuple.entries.get(i)).label);
                String actual = WordUtils.normalizeTag(predicted.get(i));
                if (actual.startsWith(expected) || expected.startsWith(actual)) {
                    continue;
                }
                if (print) {
                    LOG.info(words[i] + " exp: " + expected + " actual: " + predicted.get(i));
                }
                errCount++;
            }
        }
        if (totalCount == 0) {
            LOG.warn("No tokens to evaluate in test file: " + testFile);
        }
        double accuracy = 1.0 - (double) errCount / (double) totalCount;
        if (print) {
            // Report correct/total so the fraction agrees with the percentage that follows it.
            LOG.info("accuracy: " + (totalCount - errCount) + "/" + totalCount + " -> " + String.format("%.2f", accuracy * 100.0) + "%");
        }
        Map<String, Double> result = new HashMap<>();
        result.put("accuracy", accuracy);
        return result;
    }

    /**
     * Loads a model previously written by {@link #saveModel}.
     *
     * <p>Security: this uses Java native deserialization. Only load model files from trusted
     * sources.
     *
     * @param modelPath path of the serialized model
     * @return the model, or {@code null} if it could not be read (the error is logged)
     */
    public static HMMModel loadModel(String modelPath) {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelPath))) {
            return (HMMModel) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            LOG.error("Load model err.", e);
            return null;
        }
    }

    /**
     * Serializes {@code model} to {@code modelPath} with Java object serialization. Failures
     * are logged, not thrown.
     *
     * @param modelPath destination file
     * @param model     model to persist
     */
    public static void saveModel(String modelPath, HMMModel model) {
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelPath))) {
            oos.writeObject(model);
        } catch (IOException e) {
            LOG.error("Persist model err.", e);
        }
    }
}

