/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.sentencetype;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.maxent.MaxEntClassifier;
import org.maochen.nlp.ml.vector.IVector;
import org.maochen.nlp.ml.vector.LabeledVector;
import org.maochen.nlp.parser.DTree;
import org.maochen.nlp.parser.IParser;
import org.maochen.nlp.parser.stanford.nn.StanfordNNDepParser;
import org.maochen.nlp.parser.stanford.pcfg.StanfordPCFGParser;
import org.maochen.nlp.sentencetype.FeatureExtractor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Predicts the type of a sentence with a maximum-entropy classifier.
 *
 * <p>Features come from {@link FeatureExtractor#generateFeats} applied to the sentence text
 * and its parse tree ({@link DTree}) produced by the configured {@link IParser}. Every
 * extracted feature is encoded as a binary feature with value {@code 1.0}.
 */
public class SentenceTypeClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(SentenceTypeClassifier.class);

    /** Column separator of the training corpus ({@code label<TAB>sentence}). */
    private static final String COLUMN_SEPARATOR = "\t";

    private final MaxEntClassifier maxEntClassifier = new MaxEntClassifier();
    private final FeatureExtractor featureExtractor = new FeatureExtractor();
    private final IParser parser;

    /** Creates a classifier that parses sentences with a {@link StanfordNNDepParser}. */
    public SentenceTypeClassifier() {
        this(new StanfordNNDepParser());
    }

    /**
     * Creates a classifier that parses sentences with the given parser.
     *
     * @param parser parser used by {@link #train(String)} and {@link #predict(String)}; it is
     *     not consulted by {@link #predict(String, DTree)}
     */
    public SentenceTypeClassifier(IParser parser) {
        this.parser = parser;
    }

    /**
     * Trains the model from a tab-separated corpus file.
     *
     * <p>Each line is expected to be {@code label<TAB>sentence}; further columns are ignored.
     * Identical lines are used only once. Lines lacking a label and a sentence column are
     * skipped (and counted in a warning) instead of aborting training.
     *
     * @param trainFilePath path of the corpus file, decoded as UTF-8
     * @throws IOException if the corpus file cannot be read
     */
    public void train(String trainFilePath) throws IOException {
        Map<String, String> params = new HashMap<>();
        params.put("iterations", "120");
        maxEntClassifier.setParameter(params);

        // NOTE(review): presumably a warm-up so the parser completes any lazy initialization
        // before being called from the parallel stream below — confirm against IParser impls.
        parser.parse(".");

        List<LabeledSentence> samples = readTrainingData(trainFilePath);
        LOG.info("Loaded Training data.");

        LOG.info("Generating parse tree.");
        Map<String, DTree> depTreeCache = parseDistinctSentences(samples);

        LOG.info("Generating feats");
        List<Tuple> trainingTuples = samples.stream()
                .map(sample -> {
                    DTree parseTree = depTreeCache.get(sample.sentence);
                    List<String> feats = featureExtractor.generateFeats(sample.sentence, parseTree);
                    return new Tuple(1, toFeatureVector(feats), sample.label);
                })
                .collect(Collectors.toList());
        LOG.info("Extracted Feats.");

        maxEntClassifier.train(trainingTuples);
    }

    /**
     * Writes the trained model via {@link MaxEntClassifier#persistModel}.
     *
     * @param modelPath destination path of the model
     * @throws IOException if the model cannot be written
     */
    public void persist(String modelPath) throws IOException {
        maxEntClassifier.persistModel(modelPath);
    }

    /**
     * Loads a model via {@link MaxEntClassifier#loadModel}. The caller keeps ownership of the
     * stream and is responsible for closing it.
     *
     * @param modelPath stream containing a model (despite its name, not a path)
     * @throws IOException if the model cannot be read
     */
    public void loadModel(InputStream modelPath) throws IOException {
        maxEntClassifier.loadModel(modelPath);
    }

    /**
     * Classifies a sentence whose parse tree is already available.
     *
     * @param sentence the sentence text
     * @param tree parse tree of {@code sentence}
     * @return the map produced by the MaxEnt classifier (presumably label to score/probability)
     */
    public Map<String, Double> predict(String sentence, DTree tree) {
        List<String> feats = featureExtractor.generateFeats(sentence, tree);
        return maxEntClassifier.predict(new Tuple(toFeatureVector(feats)));
    }

    /**
     * Parses {@code sentence} with the configured parser and classifies it.
     *
     * @param sentence the sentence text
     * @return the map produced by the MaxEnt classifier (presumably label to score/probability)
     */
    public Map<String, Double> predict(String sentence) {
        return predict(sentence, parser.parse(sentence));
    }

    /** One training example: the gold label and the sentence it belongs to. */
    private static final class LabeledSentence {
        final String label;
        final String sentence;

        LabeledSentence(String label, String sentence) {
            this.label = label;
            this.sentence = sentence;
        }
    }

    /** Reads the corpus, drops duplicate lines and splits each line into label and sentence. */
    private static List<LabeledSentence> readTrainingData(String trainFilePath) throws IOException {
        Set<String> distinctLines = new HashSet<>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(trainFilePath), StandardCharsets.UTF_8))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                distinctLines.add(line);
            }
        }

        List<LabeledSentence> samples = new ArrayList<>(distinctLines.size());
        int skipped = 0;
        for (String line : distinctLines) {
            String[] columns = line.split(COLUMN_SEPARATOR);
            if (columns.length < 2) {
                // Blank or tab-less lines would otherwise throw ArrayIndexOutOfBoundsException.
                skipped++;
                continue;
            }
            samples.add(new LabeledSentence(columns[0], columns[1]));
        }
        if (skipped > 0) {
            LOG.warn("Skipped {} malformed training line(s) lacking a label<TAB>sentence pair.", skipped);
        }
        return samples;
    }

    /** Parses every distinct sentence once, in parallel, and returns sentence -> parse tree. */
    private Map<String, DTree> parseDistinctSentences(List<LabeledSentence> samples) {
        Set<String> sentences = samples.stream()
                .map(sample -> sample.sentence)
                .collect(Collectors.toSet());
        Map<String, DTree> trees = new ConcurrentHashMap<>();
        sentences.parallelStream().forEach(sentence -> trees.put(sentence, parser.parse(sentence)));
        return trees;
    }

    /** Encodes feature names as a binary vector: each feature gets the value {@code 1.0}. */
    private static IVector toFeatureVector(List<String> feats) {
        double[] values = new double[feats.size()];
        Arrays.fill(values, 1.0);
        LabeledVector vector = new LabeledVector(values);
        vector.featsName = feats.toArray(new String[0]);
        return vector;
    }

    /**
     * Trains, persists and reloads a model, then classifies sentences read from standard input
     * until {@code exit} (case-insensitive) or end of input.
     *
     * @param args optional {@code [modelDirectory, trainingCorpusPath]}; the original
     *     hard-coded locations are used for any argument that is missing
     * @throws IOException if training, persisting or loading the model fails
     */
    public static void main(String[] args) throws IOException {
        String prefix = args.length > 0
                ? args[0]
                : "/Users/Maochen/workspace/ameliang/ameliang/amelia-nlp/src/main/resources/models";
        String trainFilePath = args.length > 1
                ? args[1]
                : "/Users/Maochen/workspace/nlp-service_training-data/sentence_type_corpus.txt";
        String modelPath = prefix + "/sent_type_model.dat";
        String pcfgModel = prefix + "/englishPCFG.ser.gz";

        SentenceTypeClassifier sentenceTypeClassifier =
                new SentenceTypeClassifier(new StanfordPCFGParser(pcfgModel, null, null));
        sentenceTypeClassifier.train(trainFilePath);
        sentenceTypeClassifier.persist(modelPath);
        try (InputStream modelStream = new FileInputStream(modelPath)) {
            sentenceTypeClassifier.loadModel(modelStream);
        }

        try (Scanner scanner = new Scanner(System.in)) {
            System.out.println("Input Sentence:");
            while (scanner.hasNextLine()) {
                String sentence = scanner.nextLine();
                if (sentence.equalsIgnoreCase("exit")) {
                    break;
                }
                Map<String, Double> result = sentenceTypeClassifier.predict(sentence);
                System.out.println(result);
                // The predicted type is the key with the highest value (null if the map is empty).
                String type = result.entrySet().stream()
                        .max(Map.Entry.comparingByValue())
                        .map(Map.Entry::getKey)
                        .orElse(null);
                System.out.println(StringUtils.capitalize(type));
            }
        }
        // NOTE(review): explicit exit kept from the original; presumably needed because the
        // parser may leave non-daemon threads running — confirm.
        System.exit(0);
    }
}

