/*
 * Decompiled with CFR 0.152.
 */
package chalk.tools.postag;

import chalk.tools.dictionary.Dictionary;
import chalk.tools.postag.MutableTagDictionary;
import chalk.tools.postag.POSDictionary;
import chalk.tools.postag.POSEvaluator;
import chalk.tools.postag.POSModel;
import chalk.tools.postag.POSSample;
import chalk.tools.postag.POSTaggerEvaluationMonitor;
import chalk.tools.postag.POSTaggerFactory;
import chalk.tools.postag.POSTaggerME;
import chalk.tools.postag.TagDictionary;
import chalk.tools.util.ObjectStream;
import chalk.tools.util.TrainingParameters;
import chalk.tools.util.eval.CrossValidationPartitioner;
import chalk.tools.util.eval.Mean;
import chalk.tools.util.model.ModelType;
import chalk.tools.util.model.ModelUtil;
import java.io.File;
import java.io.IOException;

public class POSTaggerCrossValidator {
    private final String languageCode;
    private final TrainingParameters params;
    private Integer ngramCutoff;
    private Mean wordAccuracy = new Mean();
    private POSTaggerEvaluationMonitor[] listeners;
    private String factoryClassName;
    private POSTaggerFactory factory;
    private Integer tagdicCutoff = null;
    private File tagDictionaryFile;

    public POSTaggerCrossValidator(String string, TrainingParameters trainingParameters, File file, Integer n, Integer n2, String string2, POSTaggerEvaluationMonitor ... pOSTaggerEvaluationMonitorArray) {
        this.languageCode = string;
        this.params = trainingParameters;
        this.ngramCutoff = n;
        this.listeners = pOSTaggerEvaluationMonitorArray;
        this.factoryClassName = string2;
        this.tagdicCutoff = n2;
        this.tagDictionaryFile = file;
    }

    public POSTaggerCrossValidator(String string, TrainingParameters trainingParameters, POSTaggerFactory pOSTaggerFactory, POSTaggerEvaluationMonitor ... pOSTaggerEvaluationMonitorArray) {
        this.languageCode = string;
        this.params = trainingParameters;
        this.listeners = pOSTaggerEvaluationMonitorArray;
        this.factory = pOSTaggerFactory;
        this.ngramCutoff = null;
        this.tagdicCutoff = null;
    }

    public POSTaggerCrossValidator(String string, ModelType modelType, POSDictionary pOSDictionary, Dictionary dictionary, int n, int n2) {
        this(string, POSTaggerCrossValidator.create(modelType, n, n2), POSTaggerCrossValidator.create(dictionary, pOSDictionary), new POSTaggerEvaluationMonitor[0]);
    }

    public POSTaggerCrossValidator(String string, ModelType modelType, POSDictionary pOSDictionary, Dictionary dictionary) {
        this(string, POSTaggerCrossValidator.create(modelType, 5, 100), POSTaggerCrossValidator.create(dictionary, pOSDictionary), new POSTaggerEvaluationMonitor[0]);
    }

    public POSTaggerCrossValidator(String string, TrainingParameters trainingParameters, POSDictionary pOSDictionary, POSTaggerEvaluationMonitor ... pOSTaggerEvaluationMonitorArray) {
        this(string, trainingParameters, POSTaggerCrossValidator.create(null, pOSDictionary), pOSTaggerEvaluationMonitorArray);
    }

    public POSTaggerCrossValidator(String string, TrainingParameters trainingParameters, POSDictionary pOSDictionary, Integer n, POSTaggerEvaluationMonitor ... pOSTaggerEvaluationMonitorArray) {
        this(string, trainingParameters, POSTaggerCrossValidator.create(null, pOSDictionary), pOSTaggerEvaluationMonitorArray);
        this.ngramCutoff = n;
    }

    public POSTaggerCrossValidator(String string, TrainingParameters trainingParameters, POSDictionary pOSDictionary, Dictionary dictionary, POSTaggerEvaluationMonitor ... pOSTaggerEvaluationMonitorArray) {
        this(string, trainingParameters, POSTaggerCrossValidator.create(dictionary, pOSDictionary), pOSTaggerEvaluationMonitorArray);
    }

    public void evaluate(ObjectStream<POSSample> objectStream, int n) throws IOException {
        CrossValidationPartitioner<POSSample> crossValidationPartitioner = new CrossValidationPartitioner<POSSample>(objectStream, n);
        while (crossValidationPartitioner.hasNext()) {
            Object object;
            Dictionary dictionary;
            CrossValidationPartitioner.TrainingSampleStream<POSSample> trainingSampleStream = crossValidationPartitioner.next();
            if (this.factory == null) {
                this.factory = POSTaggerFactory.create(this.factoryClassName, null, null);
            }
            if ((dictionary = this.factory.getDictionary()) == null) {
                if (this.ngramCutoff != null) {
                    System.err.print("Building ngram dictionary ... ");
                    dictionary = POSTaggerME.buildNGramDictionary(trainingSampleStream, this.ngramCutoff);
                    trainingSampleStream.reset();
                    System.err.println("done");
                }
                this.factory.setDictionary(dictionary);
            }
            if (this.tagDictionaryFile != null && this.factory.getTagDictionary() == null) {
                this.factory.setTagDictionary(this.factory.createTagDictionary(this.tagDictionaryFile));
            }
            if (this.tagdicCutoff != null) {
                object = this.factory.getTagDictionary();
                if (object == null) {
                    object = this.factory.createEmptyTagDictionary();
                    this.factory.setTagDictionary((TagDictionary)object);
                }
                if (!(object instanceof MutableTagDictionary)) {
                    throw new IllegalArgumentException("Can't extend a TagDictionary that does not implement MutableTagDictionary.");
                }
                POSTaggerME.populatePOSDictionary(trainingSampleStream, (MutableTagDictionary)object, this.tagdicCutoff);
                trainingSampleStream.reset();
            }
            object = POSTaggerME.train(this.languageCode, trainingSampleStream, this.params, this.factory);
            POSEvaluator pOSEvaluator = new POSEvaluator(new POSTaggerME((POSModel)object), this.listeners);
            pOSEvaluator.evaluate(trainingSampleStream.getTestSampleStream());
            this.wordAccuracy.add(pOSEvaluator.getWordAccuracy(), pOSEvaluator.getWordCount());
            if (this.tagdicCutoff == null) continue;
            this.factory.setTagDictionary(null);
        }
    }

    public double getWordAccuracy() {
        return this.wordAccuracy.mean();
    }

    public long getWordCount() {
        return this.wordAccuracy.count();
    }

    private static TrainingParameters create(ModelType modelType, int n, int n2) {
        TrainingParameters trainingParameters = ModelUtil.createTrainingParameters(n2, n);
        trainingParameters.put("Algorithm", modelType.toString());
        return trainingParameters;
    }

    private static POSTaggerFactory create(Dictionary dictionary, TagDictionary tagDictionary) {
        return new POSTaggerFactory(dictionary, tagDictionary);
    }
}

