/*
 * Decompiled with CFR 0.152.
 */
package eus.ixa.ixa.pipe.pos.train;

import eus.ixa.ixa.pipe.pos.train.Flags;
import eus.ixa.ixa.pipe.pos.train.InputOutputUtils;
import eus.ixa.ixa.pipe.pos.train.Trainer;
import java.io.File;
import java.io.IOException;
import opennlp.tools.cmdline.TerminateToolException;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.postag.MutableTagDictionary;
import opennlp.tools.postag.POSEvaluator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSTaggerEvaluationMonitor;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.postag.TagDictionary;
import opennlp.tools.postag.WordTagSampleStream;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

public abstract class AbstractTrainer
implements Trainer {
    private final String lang;
    private final ObjectStream<POSSample> trainSamples;
    private final ObjectStream<POSSample> testSamples;
    private WordTagSampleStream dictSamples;
    private final int dictCutOff;
    private final int ngramCutOff;
    private POSTaggerFactory posTaggerFactory;

    public AbstractTrainer(TrainingParameters params) throws IOException {
        this.lang = Flags.getLanguage(params);
        String trainData = Flags.getDataSet("TrainSet", params);
        String testData = Flags.getDataSet("TestSet", params);
        ObjectStream<String> trainStream = InputOutputUtils.readFileIntoMarkableStreamFactory(trainData);
        this.trainSamples = new WordTagSampleStream(trainStream);
        ObjectStream<String> testStream = InputOutputUtils.readFileIntoMarkableStreamFactory(testData);
        this.testSamples = new WordTagSampleStream(testStream);
        ObjectStream<String> dictStream = InputOutputUtils.readFileIntoMarkableStreamFactory(trainData);
        this.setDictSamples(new WordTagSampleStream(dictStream));
        this.dictCutOff = Flags.getAutoDictFeatures(params);
        this.ngramCutOff = Flags.getNgramDictFeatures(params);
    }

    @Override
    public final POSModel train(TrainingParameters params) {
        if (this.getPosTaggerFactory() == null) {
            throw new IllegalStateException("Classes derived from AbstractTrainer must  create a POSTaggerFactory features!");
        }
        POSModel trainedModel = null;
        POSEvaluator posEvaluator = null;
        try {
            trainedModel = POSTaggerME.train(this.lang, this.trainSamples, params, this.getPosTaggerFactory());
            POSTaggerME posTagger = new POSTaggerME(trainedModel);
            posEvaluator = new POSEvaluator(posTagger, new POSTaggerEvaluationMonitor[0]);
            posEvaluator.evaluate(this.testSamples);
        }
        catch (IOException e) {
            System.err.println("IO error while loading traing and test sets!");
            e.printStackTrace();
            System.exit(1);
        }
        System.out.println("Final result: " + posEvaluator.getWordAccuracy());
        return trainedModel;
    }

    protected final void createTagDictionary(String dictPath) {
        if (!dictPath.equalsIgnoreCase("off")) {
            try {
                this.getPosTaggerFactory().setTagDictionary(this.getPosTaggerFactory().createTagDictionary(new File(dictPath)));
            }
            catch (IOException e) {
                throw new TerminateToolException(-1, "IO error while loading POS Dictionary: " + e.getMessage(), e);
            }
        }
    }

    protected final void createAutomaticDictionary(ObjectStream<POSSample> aDictSamples, int aDictCutOff) {
        if (aDictCutOff != -1) {
            try {
                TagDictionary dict = this.getPosTaggerFactory().getTagDictionary();
                if (dict == null) {
                    dict = this.getPosTaggerFactory().createEmptyTagDictionary();
                    this.getPosTaggerFactory().setTagDictionary(dict);
                }
                if (!(dict instanceof MutableTagDictionary)) {
                    throw new IllegalArgumentException("Can't extend a POSDictionary that does not implement MutableTagDictionary.");
                }
                POSTaggerME.populatePOSDictionary(aDictSamples, (MutableTagDictionary)dict, aDictCutOff);
                this.dictSamples.reset();
            }
            catch (IOException e) {
                throw new TerminateToolException(-1, "IO error while creating/extending POS Dictionary: " + e.getMessage(), e);
            }
        }
    }

    protected final Dictionary createNgramDictionary(ObjectStream<POSSample> aDictSamples, int aNgramCutoff) {
        Dictionary ngramDict = null;
        if (aNgramCutoff != -1) {
            System.err.print("Building ngram dictionary ... ");
            try {
                ngramDict = POSTaggerME.buildNGramDictionary(aDictSamples, aNgramCutoff);
                this.dictSamples.reset();
            }
            catch (IOException e) {
                throw new TerminateToolException(-1, "IO error while building NGram Dictionary: " + e.getMessage(), e);
            }
            System.err.println("done");
        }
        return ngramDict;
    }

    protected final WordTagSampleStream getDictSamples() {
        return this.dictSamples;
    }

    protected final void setDictSamples(WordTagSampleStream aDictSamples) {
        this.dictSamples = aDictSamples;
    }

    protected final POSTaggerFactory getPosTaggerFactory() {
        return this.posTaggerFactory;
    }

    protected final void setPosTaggerFactory(POSTaggerFactory aPosTaggerFactory) {
        this.posTaggerFactory = aPosTaggerFactory;
    }

    protected final Integer getDictCutOff() {
        return this.dictCutOff;
    }

    protected final Integer getNgramDictCutOff() {
        return this.ngramCutOff;
    }
}

