/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.wordcorrection;

import edu.stanford.nlp.ling.CoreLabel;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.maochen.nlp.datastructure.DoubleKeyMap;
import org.maochen.nlp.parser.stanford.StanfordParser;

/**
 * Single-word spelling corrector.
 *
 * <p>{@link #buildModel(String)} counts the tokens of a text file (tokenized with
 * {@link StanfordParser#stanfordTokenize}) and normalizes the counts into a unigram probability
 * {@code P(word)}. {@link #predict(String)} returns its input unchanged when it is a known word.
 * Otherwise it considers every known word {@code c} within one edit of the input. An edit is a
 * deletion, an adjacent transposition, or the substitution or insertion of a letter
 * {@code 'a'..'z'}. It returns the candidate maximizing {@code P(wrong | c) * P(c)}. Here
 * {@code P(wrong | c)} is uniform over the one-edit variants of {@code c} that are not themselves
 * known words.
 *
 * <p>Not thread-safe: {@link #predict(String)} rebuilds shared model state on every call.
 */
public class SingleWordCorrection {
    private final Model model = new Model();

    /**
     * Creates a corrector with an empty model.
     *
     * <p>Call {@link #buildModel(String)} or {@link #restoreModel(String)} before
     * {@link #predict(String)}.
     */
    public SingleWordCorrection() {
        // The field initializer already creates the model; nothing else to set up.
    }

    /**
     * Generates every string exactly one edit away from {@code word}.
     *
     * <p>The edits are:
     * <ul>
     *   <li>deleting one character;</li>
     *   <li>swapping two adjacent characters;</li>
     *   <li>replacing one character with a letter {@code 'a'..'z'};</li>
     *   <li>inserting a letter {@code 'a'..'z'} at any position.</li>
     * </ul>
     * The input word itself is removed from the result. It can be produced by, for example,
     * substituting a letter with itself.
     *
     * @param word the word to perturb; must be non-null and non-empty
     * @return map whose keys are the generated strings, each with an unnormalized weight of 1.0
     * @throws IllegalArgumentException if {@code word} is null or empty
     */
    private Map<String, Double> distance1Generation(String word) {
        if (word == null || word.isEmpty()) {
            throw new IllegalArgumentException("Input word Error: " + word);
        }
        Map<String, Double> result = new HashMap<>();
        for (int i = 0; i < word.length(); ++i) {
            String prefix = word.substring(0, i);
            char current = word.charAt(i);
            String suffix = word.substring(i + 1);
            // Deletion of word[i].
            result.put(prefix + suffix, 1.0);
            // Transposition of word[i] and word[i + 1].
            if (i + 1 < word.length()) {
                result.put(prefix + word.charAt(i + 1) + current + word.substring(i + 2), 1.0);
            }
            // Substitution of word[i].
            for (char letter = 'a'; letter <= 'z'; ++letter) {
                result.put(prefix + letter + suffix, 1.0);
            }
            // Insertion before and after word[i]; over all i this covers every insertion point.
            for (char letter = 'a'; letter <= 'z'; ++letter) {
                result.put(prefix + letter + current + suffix, 1.0);
                result.put(prefix + current + letter + suffix, 1.0);
            }
        }
        result.remove(word);
        return result;
    }

    /** Returns the candidate error strings for {@code word} (currently edit distance 1 only). */
    private Map<String, Double> errWordgenerating(String word) {
        return this.distance1Generation(word);
    }

    /**
     * Builds the unigram model from a text file, replacing any previously built or restored model.
     *
     * <p>Each line is tokenized with {@link StanfordParser#stanfordTokenize}. The original text of
     * every token is counted and empty tokens are dropped. The counts are then normalized to sum
     * to 1. The file is read with the platform default charset.
     *
     * @param wordFileName path of the training text file
     * @throws IOException if the file cannot be read; the current model is left unchanged then
     */
    public void buildModel(String wordFileName) throws IOException {
        // Count into a fresh map. The model map holds normalized probabilities, so adding raw
        // counts to it (on a second call, or after restoreModel) would mix incompatible units.
        Map<String, Double> counts = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(wordFileName)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                List<String> tokens = StanfordParser.stanfordTokenize(line).stream()
                        .map(CoreLabel::originalText)
                        .collect(Collectors.toList());
                for (String token : tokens) {
                    counts.merge(token, 1.0, Double::sum);
                }
            }
        }
        counts.remove("");
        this.model.wordProbability = counts;
        this.model.normalizeWordProbability();
    }

    /**
     * Suggests a correction for a single word.
     *
     * @param wrongWord the word to correct
     * @return {@code wrongWord} itself if it is in the model's vocabulary. Otherwise, the known
     *     word within one edit of it that maximizes
     *     {@code P(wrongWord | candidate) * P(candidate)}. Returns {@code ""} if no candidate's
     *     error set contains {@code wrongWord}.
     * @throws IllegalArgumentException if {@code wrongWord} is not a known word and is null or
     *     empty
     * @throws RuntimeException with message "No Correction Suggestion" if no known word within one
     *     edit has any non-word error variants
     */
    public String predict(String wrongWord) {
        if (this.model.wordProbability.containsKey(wrongWord)) {
            return wrongWord;
        }
        // Rebuild the error table for this query only. Reusing the previous table would let it
        // grow without bound, and the emptiness check and normalization would depend on earlier
        // calls.
        this.model.derivedWordProbability = new DoubleKeyMap<>();
        for (String candidate : this.errWordgenerating(wrongWord).keySet()) {
            if (!this.model.wordProbability.containsKey(candidate)) {
                continue;
            }
            // Variants of the candidate that are themselves real words are not treated as typos.
            Map<String, Double> errWordMap = this.errWordgenerating(candidate);
            errWordMap.keySet().removeIf(this.model.wordProbability::containsKey);
            for (Map.Entry<String, Double> entry : errWordMap.entrySet()) {
                this.model.derivedWordProbability.put(candidate, entry.getKey(), entry.getValue());
            }
        }
        if (this.model.derivedWordProbability.size() == 0) {
            throw new RuntimeException("No Correction Suggestion");
        }
        this.model.normalizeDerivedWordProbability();

        double argmaxProb = 0.0;
        String argmaxWord = "";
        Map<String, Double> pwc = this.model.derivedWordProbability.column(wrongWord);
        if (pwc == null) {
            // NOTE(review): defensive; unclear whether DoubleKeyMap.column returns null or an
            // empty map for an unseen key — confirm.
            return argmaxWord;
        }
        for (Map.Entry<String, Double> entry : pwc.entrySet()) {
            String correctWord = entry.getKey();
            double errorProb = entry.getValue();
            // Row keys were filtered through wordProbability above, so this lookup is non-null.
            double wordProb = this.model.wordProbability.get(correctWord);
            System.out.println("[predict] Possible: " + correctWord + "=" + errorProb
                    + "\t|\twordProb: " + wordProb);
            double score = errorProb * wordProb;
            // ">=" keeps the last of equally scored candidates.
            if (score >= argmaxProb) {
                argmaxProb = score;
                argmaxWord = correctWord;
            }
        }
        return argmaxWord;
    }

    /**
     * Serializes the model to {@code filename} using Java serialization.
     *
     * <p>Only the word probabilities are written; the per-query error table is {@code transient}.
     *
     * @param filename destination file path
     * @throws IOException if the file cannot be written
     */
    public void persistModel(String filename) throws IOException {
        this.model.persist(filename);
    }

    /**
     * Loads word probabilities previously written by {@link #persistModel(String)}.
     *
     * <p>Uses Java native deserialization, which is unsafe on untrusted input. Only restore files
     * from a trusted source.
     *
     * @param filename source file path
     * @throws IOException if the file cannot be read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public void restoreModel(String filename) throws IOException, ClassNotFoundException {
        this.model.restore(filename);
    }

    /**
     * Demo entry point: corrects one word.
     *
     * <p>Trains on the classpath resource {@code the_adventures_of_sherlock_holmes.txt}, corrects
     * the word "prob", and prints the result and the prediction time.
     *
     * @param args ignored
     * @throws IOException if the training resource is missing or cannot be read
     */
    public static void main(String[] args) throws IOException {
        String resourceName = "the_adventures_of_sherlock_holmes.txt";
        URL resource = SingleWordCorrection.class.getClassLoader().getResource(resourceName);
        if (resource == null) {
            throw new FileNotFoundException("Classpath resource not found: " + resourceName);
        }
        String path;
        try {
            // URL.getFile() keeps percent-escapes (e.g. "%20" for a space) that FileReader cannot
            // open; going through a URI decodes them.
            path = new File(resource.toURI()).getPath();
        } catch (URISyntaxException e) {
            throw new IOException("Invalid resource URI: " + resource, e);
        }
        SingleWordCorrection swc = new SingleWordCorrection();
        swc.buildModel(path);
        String word = "prob";
        long start = System.nanoTime();
        String predict = swc.predict(word);
        long elapsedMs = (System.nanoTime() - start) / 1_000_000L;
        System.out.println(word + "->" + predict);
        System.out.println("\nPredict Time Elapse: " + elapsedMs + "ms");
    }

    /** Serializable model state: unigram word probabilities plus a per-query error table. */
    static class Model
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Word to normalized unigram probability (raw counts only while building). */
        public Map<String, Double> wordProbability = new HashMap<String, Double>();
        /** Row: known candidate word; column: error string; value: (normalized) weight. */
        public transient DoubleKeyMap<String, String, Double> derivedWordProbability =
                new DoubleKeyMap<>();

        Model() {
        }

        /** Divides every value in {@link #wordProbability} by the sum of all values. */
        public void normalizeWordProbability() {
            double total = 0.0;
            for (double count : this.wordProbability.values()) {
                total += count;
            }
            final double totalCount = total;
            this.wordProbability.replaceAll((word, count) -> count / totalCount);
        }

        /** Normalizes each row of {@link #derivedWordProbability} so its values sum to 1. */
        public void normalizeDerivedWordProbability() {
            for (String key1 : this.derivedWordProbability.rowKeySet()) {
                Map<String, Double> row = this.derivedWordProbability.row(key1);
                double totalCount = 0.0;
                for (double weight : row.values()) {
                    totalCount += weight;
                }
                for (String key2 : row.keySet()) {
                    this.derivedWordProbability.put(
                            key1, key2, this.derivedWordProbability.get(key1, key2) / totalCount);
                }
            }
        }

        /**
         * Writes this model to {@code filename} with Java serialization.
         *
         * @param filename destination file path
         * @throws IOException if the file cannot be written
         */
        public void persist(String filename) throws IOException {
            // Declare both streams so the file handle is closed even if the ObjectOutputStream
            // constructor fails while writing its header.
            try (FileOutputStream fos = new FileOutputStream(filename);
                 ObjectOutputStream oos = new ObjectOutputStream(fos)) {
                oos.writeObject(this);
            }
        }

        /**
         * Replaces {@link #wordProbability} with the one stored in {@code filename}.
         *
         * <p>SECURITY: Java native deserialization is unsafe on untrusted files.
         *
         * @param filename source file path
         * @throws IOException if the file cannot be read
         * @throws ClassNotFoundException if a serialized class cannot be resolved
         */
        public void restore(String filename) throws IOException, ClassNotFoundException {
            try (FileInputStream fis = new FileInputStream(filename);
                 ObjectInputStream ois = new ObjectInputStream(fis)) {
                Model reload = (Model) ois.readObject();
                this.wordProbability = reload.wordProbability;
            }
        }
    }
}

