/*
 * Decompiled with CFR 0.152.
 */
package eus.ixa.ixa.pipe.ml;

import eus.ixa.ixa.pipe.ml.features.XMLFeatureDescriptor;
import eus.ixa.ixa.pipe.ml.formats.CoNLL02Format;
import eus.ixa.ixa.pipe.ml.formats.CoNLL03Format;
import eus.ixa.ixa.pipe.ml.formats.LemmatizerFormat;
import eus.ixa.ixa.pipe.ml.formats.TabulatedFormat;
import eus.ixa.ixa.pipe.ml.resources.LoadModelResources;
import eus.ixa.ixa.pipe.ml.sequence.BilouCodec;
import eus.ixa.ixa.pipe.ml.sequence.BioCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelSample;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelSampleTypeFilter;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerEvaluationMonitor;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerEvaluator;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerFactory;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerME;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerModel;
import eus.ixa.ixa.pipe.ml.utils.Flags;
import eus.ixa.ixa.pipe.ml.utils.IOUtils;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Map;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/**
 * Trains a sequence labeling model and evaluates it right after training.
 *
 * <p>The training and evaluation corpora are located through the
 * {@code TrainSet} and {@code TestSet} entries of the supplied
 * {@link TrainingParameters}. When a {@code Types} entry is present, both
 * streams are wrapped in a {@link SequenceLabelSampleTypeFilter} built from
 * its comma-separated values.
 */
public class SequenceLabelerTrainer {
    private final String lang;
    private final String trainData;
    private final String testData;
    private ObjectStream<SequenceLabelSample> trainSamples;
    private ObjectStream<SequenceLabelSample> testSamples;
    private final String corpusFormat;
    private String sequenceCodec;
    private final String clearTrainingFeatures;
    private final String clearEvaluationFeatures;
    private SequenceLabelerFactory nameClassifierFactory;

    /**
     * Reads the configuration from {@code params}, opens the training and
     * test sample streams and creates the {@link SequenceLabelerFactory}.
     *
     * @param params the training parameters
     * @throws IOException if the sample streams or the factory resources
     *         cannot be loaded
     */
    public SequenceLabelerTrainer(TrainingParameters params) throws IOException {
        this.lang = Flags.getLanguage(params);
        this.clearTrainingFeatures = Flags.getClearTrainingFeatures(params);
        this.clearEvaluationFeatures = Flags.getClearEvaluationFeatures(params);
        this.corpusFormat = Flags.getCorpusFormat(params);
        this.trainData = params.getSettings().get("TrainSet");
        this.testData = params.getSettings().get("TestSet");
        this.trainSamples = SequenceLabelerTrainer.getSequenceStream(this.trainData, this.clearTrainingFeatures, this.corpusFormat);
        this.testSamples = SequenceLabelerTrainer.getSequenceStream(this.testData, this.clearEvaluationFeatures, this.corpusFormat);
        this.sequenceCodec = Flags.getSequenceCodec(params);
        // Look the setting up once; it is optional.
        String types = params.getSettings().get("Types");
        if (types != null) {
            String[] neTypes = types.split(",");
            this.trainSamples = new SequenceLabelSampleTypeFilter(neTypes, this.trainSamples);
            this.testSamples = new SequenceLabelSampleTypeFilter(neTypes, this.testSamples);
        }
        this.createSequenceLabelerFactory(params);
    }

    /**
     * Builds the {@link SequenceLabelerFactory} from the XML feature
     * descriptor, the model resources and the configured sequence codec, and
     * stores it in this trainer. The generated feature descriptor is echoed
     * to standard error.
     *
     * @param params the training parameters
     * @throws IOException declared by the resource/descriptor loading calls
     */
    public void createSequenceLabelerFactory(TrainingParameters params) throws IOException {
        String codecClassName = this.getSequenceCodec();
        SequenceLabelerCodec<String> codec = SequenceLabelerFactory.instantiateSequenceCodec(codecClassName);
        String featureDescription = XMLFeatureDescriptor.createXMLFeatureDescriptor(params);
        System.err.println(featureDescription);
        byte[] featureGeneratorBytes = featureDescription.getBytes(Charset.forName("UTF-8"));
        Map<String, Object> resources = LoadModelResources.loadSequenceResources(params);
        this.setSequenceLabelerFactory(SequenceLabelerFactory.create(SequenceLabelerFactory.class.getName(), featureGeneratorBytes, resources, codec));
    }

    /**
     * Trains a model on the training samples and then evaluates it on the
     * test samples, printing the scores to standard output.
     *
     * <p>An {@link IOException} raised during training terminates the JVM
     * with exit status 1 after printing a diagnostic.
     *
     * @param params the training parameters
     * @return the trained model
     * @throws IllegalStateException if no factory has been set
     */
    public final SequenceLabelerModel train(TrainingParameters params) {
        if (this.getSequenceLabelerFactory() == null) {
            throw new IllegalStateException("The SequenceLabelerFactory must be instantiated!!");
        }
        SequenceLabelerModel trainedModel = null;
        try {
            trainedModel = SequenceLabelerME.train(this.lang, this.trainSamples, params, this.nameClassifierFactory);
            SequenceLabelerME seqLabeler = new SequenceLabelerME(trainedModel);
            this.trainingEvaluate(seqLabeler);
        }
        catch (IOException e) {
            System.err.println("IO error while loading traing and test sets!");
            e.printStackTrace();
            System.exit(1);
        }
        return trainedModel;
    }

    /**
     * Evaluates {@code sequenceLabeler} on the test samples and prints the
     * result: word and sentence accuracy for the {@code lemmatizer} and
     * {@code tabulated} corpus formats, F-measure for every other format.
     *
     * <p>Evaluation is best-effort: an {@link IOException} is reported on
     * standard error and no scores are printed, since figures obtained from
     * an aborted evaluation would be misleading.
     *
     * @param sequenceLabeler the labeler wrapping the freshly trained model
     */
    private void trainingEvaluate(SequenceLabelerME sequenceLabeler) {
        boolean accuracyBased = "lemmatizer".equalsIgnoreCase(this.corpusFormat)
                || "tabulated".equalsIgnoreCase(this.corpusFormat);
        SequenceLabelerEvaluator evaluator;
        if (accuracyBased) {
            // The accuracy-based evaluator is additionally given the training samples.
            evaluator = new SequenceLabelerEvaluator(this.trainSamples, this.corpusFormat, sequenceLabeler, new SequenceLabelerEvaluationMonitor[0]);
        } else {
            evaluator = new SequenceLabelerEvaluator(this.corpusFormat, sequenceLabeler, new SequenceLabelerEvaluationMonitor[0]);
        }
        try {
            evaluator.evaluate(this.testSamples);
        }
        catch (IOException e) {
            System.err.println("IO error while evaluating on the test set; no scores are reported.");
            e.printStackTrace();
            return;
        }
        if (accuracyBased) {
            System.out.println();
            System.out.println("Word Accuracy: " + evaluator.getWordAccuracy());
            System.out.println("Sentence Accuracy: " + evaluator.getSentenceAccuracy());
        } else {
            System.out.println("Final Result: \n" + evaluator.getFMeasure());
        }
    }

    /**
     * Opens {@code inputData} and wraps it in the sample stream matching
     * {@code aCorpusFormat}. Recognised formats (case-insensitive) are
     * {@code conll03}, {@code conll02}, {@code tabulated} and
     * {@code lemmatizer}.
     *
     * <p>Any other value, including {@code null}, prints a diagnostic to
     * standard error and terminates the JVM with exit status 1; the input is
     * not opened in that case.
     *
     * @param inputData the corpus location handed to {@link IOUtils}
     * @param clearFeatures the clear-features setting passed to the format reader
     * @param aCorpusFormat the corpus format name
     * @return the sample stream for the requested format
     * @throws IOException if the input cannot be read
     */
    public static ObjectStream<SequenceLabelSample> getSequenceStream(String inputData, String clearFeatures, String aCorpusFormat) throws IOException {
        ObjectStream<SequenceLabelSample> samples = null;
        // Constant-first comparisons keep a null format on the "not valid" path
        // instead of failing with a NullPointerException.
        if ("conll03".equalsIgnoreCase(aCorpusFormat)) {
            ObjectStream<String> lines = IOUtils.readFileIntoMarkableStreamFactory(inputData);
            samples = new CoNLL03Format(clearFeatures, lines);
        } else if ("conll02".equalsIgnoreCase(aCorpusFormat)) {
            ObjectStream<String> lines = IOUtils.readFileIntoMarkableStreamFactory(inputData);
            samples = new CoNLL02Format(clearFeatures, lines);
        } else if ("tabulated".equalsIgnoreCase(aCorpusFormat)) {
            ObjectStream<String> lines = IOUtils.readFileIntoMarkableStreamFactory(inputData);
            samples = new TabulatedFormat(clearFeatures, lines);
        } else if ("lemmatizer".equalsIgnoreCase(aCorpusFormat)) {
            ObjectStream<String> lines = IOUtils.readFileIntoMarkableStreamFactory(inputData);
            samples = new LemmatizerFormat(clearFeatures, lines);
        } else {
            System.err.println("Test set corpus format not valid!!");
            System.exit(1);
        }
        return samples;
    }

    /**
     * Returns the factory used for training.
     *
     * @return the current factory, or {@code null} if none has been set
     */
    public final SequenceLabelerFactory getSequenceLabelerFactory() {
        return this.nameClassifierFactory;
    }

    /**
     * Replaces the factory used for training.
     *
     * @param tokenNameFinderFactory the factory to use
     * @return the factory that was just stored
     */
    public final SequenceLabelerFactory setSequenceLabelerFactory(SequenceLabelerFactory tokenNameFinderFactory) {
        this.nameClassifierFactory = tokenNameFinderFactory;
        return this.nameClassifierFactory;
    }

    /**
     * Maps the configured codec name to the implementing class name.
     *
     * @return the class name of {@link BioCodec} for {@code "BIO"}, of
     *         {@link BilouCodec} for {@code "BILOU"}, and {@code null} for
     *         any other value
     */
    public final String getSequenceCodec() {
        if ("BIO".equals(this.sequenceCodec)) {
            return BioCodec.class.getName();
        }
        if ("BILOU".equals(this.sequenceCodec)) {
            return BilouCodec.class.getName();
        }
        return null;
    }

    /**
     * Sets the codec name consulted by {@link #getSequenceCodec()}.
     *
     * @param aSeqCodec the codec name, e.g. {@code "BIO"} or {@code "BILOU"}
     */
    public final void setSequenceCodec(String aSeqCodec) {
        this.sequenceCodec = aSeqCodec;
    }
}

