/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.app.chunker;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.tagger.maxent.MaxentTagger;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.Scanner;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.app.chunker.ChunkerFeatureExtractor;
import org.maochen.nlp.app.featextractor.IFeatureExtractor;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.crfsuite.CRFClassifier;
import org.maochen.nlp.ml.util.TrainingDataUtils;
import org.maochen.nlp.ml.vector.IVector;
import org.maochen.nlp.ml.vector.LabeledVector;
import org.maochen.nlp.util.ValidationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Sequence chunker built on top of {@link CRFClassifier}.
 *
 * <p>Each token is represented by its word and its POS tag. Features are produced by the
 * configured {@link #featureExtractor} before training and prediction.
 */
public class CRFChunker extends CRFClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(CRFChunker.class);

    /** Defaults used by {@link #main} when no command-line arguments override them. */
    private static final String DEFAULT_MODEL_PATH = "/Users/mguan/Desktop/chunker.crf.model";
    private static final String DEFAULT_TRAIN_FILE = "/Users/mguan/workspace/nlp-training-data/corpora/CoNLL_Shared_Task/CoNLL_2000_Chunking/train.txt";
    private static final String DEFAULT_TEST_FILE = "/Users/mguan/workspace/nlp-training-data/corpora/CoNLL_Shared_Task/CoNLL_2000_Chunking/test.txt";
    private static final String POS_TAGGER_MODEL = "edu/stanford/nlp/models/pos-tagger/english-left3words/english-left3words-distsim.tagger";
    private static final String QUIT_REGEX = "q|quit|exit";

    /**
     * Column delimiter used when reading training/test files.
     *
     * <p>Public and mutable so callers can switch file formats; {@link #main} sets it to a single space.
     */
    public static String TRAIN_FILE_DELIMITER = "\t";

    /** Feature extractor applied to every sequence; must be set before train/predict/validate. */
    public IFeatureExtractor featureExtractor;

    /**
     * Trains the CRF model from a sequence file.
     *
     * @param trainFilePath path to a file readable by {@link TrainingDataUtils#readSeqFile}
     * @throws FileNotFoundException if the file cannot be opened
     * @throws IllegalStateException if {@link #featureExtractor} has not been set
     */
    public void train(String trainFilePath) throws FileNotFoundException {
        IFeatureExtractor extractor = requireFeatureExtractor();
        List<SequenceTuple> trainingData = readSequences(trainFilePath);
        LOG.info("Loaded Training data.");
        LOG.info("Generating feats");
        // Replace each sequence's raw (word, pos) entries with extracted feature tuples.
        for (SequenceTuple seq : trainingData) {
            seq.entries = extractor.extractFeat(seq);
        }
        LOG.info("Extracted Feats.");
        super.train(trainingData);
    }

    /**
     * Predicts a chunk tag for every token of a sentence.
     *
     * @param words the tokens of the sentence
     * @param pos   the POS tag of each token; must be parallel to {@code words}
     * @return a sequence whose entries are the feature-extracted tuples, each with {@code label}
     *         set to the predicted tag
     * @throws IllegalArgumentException if {@code words} and {@code pos} differ in length
     * @throws IllegalStateException    if {@link #featureExtractor} has not been set
     */
    public SequenceTuple predict(String[] words, String[] pos) {
        if (words.length != pos.length) {
            throw new IllegalArgumentException(
                    "words and pos must have the same length: " + words.length + " != " + pos.length);
        }
        IFeatureExtractor extractor = requireFeatureExtractor();

        SequenceTuple st = new SequenceTuple();
        st.entries = new ArrayList<>(words.length);
        for (int i = 0; i < words.length; ++i) {
            LabeledVector v = new LabeledVector(new String[]{words[i], pos[i]});
            st.entries.add(new Tuple((IVector) v));
        }
        st.entries = extractor.extractFeat(st);

        List<Pair<String, Double>> result = super.predict(st);
        List<String> tags = result.stream().map(Pair::getLeft).collect(Collectors.toList());
        for (int i = 0; i < tags.size(); ++i) {
            ((Tuple) st.entries.get(i)).label = tags.get(i);
        }
        return st;
    }

    /**
     * Evaluates token-level tagging accuracy against a labeled sequence file and prints the result.
     *
     * <p>Every sequence with at least one mistagged token is printed once alongside its gold tags.
     *
     * @param testFile path to a labeled file in the same format as the training file
     * @throws FileNotFoundException if the file cannot be opened
     */
    public void validate(String testFile) throws FileNotFoundException {
        List<SequenceTuple> testData = readSequences(testFile);
        int errCount = 0;
        int total = 0;
        for (SequenceTuple gold : testData) {
            total += gold.entries.size();
            ArrayList<String> expectedTags = new ArrayList<String>(gold.getLabel());
            // Column 0 holds the word and column 1 the POS tag (as built by predict()).
            String[] words = gold.entries.stream().map(x -> ((LabeledVector) x.vector).featsName[0]).toArray(String[]::new);
            String[] pos = gold.entries.stream().map(x -> ((LabeledVector) x.vector).featsName[1]).toArray(String[]::new);
            SequenceTuple predicted = predict(words, pos);

            boolean isThisSeqPrinted = false;
            for (int i = 0; i < expectedTags.size(); ++i) {
                if (expectedTags.get(i).equals(((Tuple) predicted.entries.get(i)).label)) {
                    continue;
                }
                if (!isThisSeqPrinted) {
                    ValidationUtils.printSequenceTuple(predicted, expectedTags);
                    System.out.println("");
                    isThisSeqPrinted = true;
                }
                ++errCount;
            }
        }
        System.out.println("Err/Total:\t" + errCount + "/" + total);
        if (total == 0) {
            // Avoid printing NaN% for an empty test set.
            System.out.println("Accuracy:\tN/A (no tokens in " + testFile + ")");
        } else {
            System.out.println("Accuracy:\t" + (1.0 - (double) errCount / (double) total) * 100.0 + "%");
        }
    }

    /**
     * Trains, validates and then runs an interactive chunking loop on stdin.
     *
     * <p>Optional arguments: {@code [trainFile [testFile [modelPath]]]}; missing ones fall back
     * to the built-in defaults. Enter {@code q}, {@code quit} or {@code exit} (or send EOF) to stop.
     */
    public static void main(String[] args) throws IOException {
        String trainFile = args.length > 0 ? args[0] : DEFAULT_TRAIN_FILE;
        String testFile = args.length > 1 ? args[1] : DEFAULT_TEST_FILE;
        String modelPath = args.length > 2 ? args[2] : DEFAULT_MODEL_PATH;

        CRFChunker chunker = new CRFChunker();
        chunker.featureExtractor = new ChunkerFeatureExtractor();
        // The default CoNLL-2000 corpus files are read with a single-space delimiter.
        TRAIN_FILE_DELIMITER = " ";

        Properties para = new Properties();
        para.setProperty("model", modelPath);
        para.setProperty("algorithm", "l2sgd");
        para.setProperty("feature.possible_transitions", "1");
        para.setProperty("feature.possible_states", "1");
        chunker.setParameter(para);

        chunker.train(trainFile);
        chunker.validate(testFile);

        MaxentTagger posTagger = new MaxentTagger(POS_TAGGER_MODEL);
        try (Scanner scan = new Scanner(System.in)) {
            while (true) {
                System.out.println("Please enter sentence:");
                if (!scan.hasNextLine()) {
                    break; // EOF on stdin: nextLine() would throw NoSuchElementException.
                }
                String input = scan.nextLine().trim();
                if (input.matches(QUIT_REGEX)) {
                    break;
                }
                if (input.isEmpty()) {
                    continue;
                }
                // "\\s+" on trimmed input avoids empty tokens from leading/repeated whitespace.
                String[] words = input.split("\\s+");
                List<CoreLabel> tokens = Arrays.stream(words).map(CRFChunker::toCoreLabel).collect(Collectors.toList());
                List<TaggedWord> tagged = posTagger.tagSentence(tokens);
                String[] pos = tagged.stream().map(TaggedWord::tag).toArray(String[]::new);
                SequenceTuple st = chunker.predict(words, pos);
                ValidationUtils.printSequenceTuple(st, null);
            }
        }
    }

    /** Reads a sequence file with {@link #TRAIN_FILE_DELIMITER}, always closing the stream. */
    private static List<SequenceTuple> readSequences(String path) throws FileNotFoundException {
        FileInputStream in = new FileInputStream(new File(path));
        try {
            return TrainingDataUtils.readSeqFile(in, TRAIN_FILE_DELIMITER, 2);
        } finally {
            try {
                in.close();
            } catch (IOException e) {
                // Data has already been read (or reading failed); a close failure is not fatal.
                LOG.warn("Failed to close {}", path, e);
            }
        }
    }

    /** Returns the configured feature extractor, failing fast with a clear message if it is unset. */
    private IFeatureExtractor requireFeatureExtractor() {
        if (featureExtractor == null) {
            throw new IllegalStateException("featureExtractor must be set before training or prediction");
        }
        return featureExtractor;
    }

    /** Wraps a raw token in a {@link CoreLabel} so it can be fed to the Stanford POS tagger. */
    private static CoreLabel toCoreLabel(String word) {
        CoreLabel label = new CoreLabel();
        label.setWord(word);
        label.setOriginalText(word);
        label.setValue(word);
        return label;
    }
}

