/*
 * Decompiled with CFR 0.152.
 */
package eus.ixa.ixa.pipe.ml;

import eus.ixa.ixa.pipe.ml.features.XMLFeatureDescriptor;
import eus.ixa.ixa.pipe.ml.parse.AncoraHeadRules;
import eus.ixa.ixa.pipe.ml.parse.HeadRules;
import eus.ixa.ixa.pipe.ml.parse.Parse;
import eus.ixa.ixa.pipe.ml.parse.ParseSampleStream;
import eus.ixa.ixa.pipe.ml.parse.ParserEvaluationMonitor;
import eus.ixa.ixa.pipe.ml.parse.ParserEvaluator;
import eus.ixa.ixa.pipe.ml.parse.ParserFactory;
import eus.ixa.ixa.pipe.ml.parse.ParserModel;
import eus.ixa.ixa.pipe.ml.parse.PennTreebankHeadRules;
import eus.ixa.ixa.pipe.ml.parse.ShiftReduceParser;
import eus.ixa.ixa.pipe.ml.resources.LoadModelResources;
import eus.ixa.ixa.pipe.ml.sequence.BilouCodec;
import eus.ixa.ixa.pipe.ml.sequence.BioCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerFactory;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerModel;
import eus.ixa.ixa.pipe.ml.utils.Flags;
import eus.ixa.ixa.pipe.ml.utils.IOUtils;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.Map;
import opennlp.tools.cmdline.TerminateToolException;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/**
 * Trains and evaluates a {@link ShiftReduceParser}.
 *
 * <p>The trainer reads the training and test parse streams named by the
 * {@code TrainSet} and {@code TestSet} settings of the parser parameters,
 * loads the language-specific head rules, and builds the parser, tagger and
 * chunker factories that are later handed to
 * {@code ShiftReduceParser.train(...)}.
 */
public class ShiftReduceParserTrainer {
    /** Language code as returned by {@code Flags.getLanguage(params)}. */
    private final String lang;
    /** Value of the {@code TrainSet} setting. */
    private final String trainData;
    /** Value of the {@code TestSet} setting. */
    private final String testData;
    /** Parse samples read from {@link #trainData}. */
    private final ObjectStream<Parse> trainSamples;
    /** Parse samples read from {@link #testData}. */
    private final ObjectStream<Parse> testSamples;
    /** Head rules loaded by {@link #getHeadRules(TrainingParameters)}. */
    private final HeadRules rules;
    private ParserFactory parserFactory;
    /** Short codec name ("BIO" or "BILOU"); see {@link #getSequenceCodec()}. */
    private String sequenceCodec;
    private SequenceLabelerFactory taggerFactory;
    private SequenceLabelerFactory chunkerFactory;

    /**
     * Creates a trainer that will train the parser, the tagger and the chunker.
     *
     * @param params parser training parameters; also supply the language, the
     *        {@code TrainSet}/{@code TestSet} settings and the head rules file
     * @param taggerParams parameters used to build the tagger factory
     * @param chunkerParams parameters used to build the chunker factory
     * @throws IOException if the data, head rules or resources cannot be read
     */
    public ShiftReduceParserTrainer(TrainingParameters params, TrainingParameters taggerParams, TrainingParameters chunkerParams) throws IOException {
        this.lang = Flags.getLanguage(params);
        this.trainData = (String)params.getSettings().get("TrainSet");
        this.testData = (String)params.getSettings().get("TestSet");
        this.trainSamples = ShiftReduceParserTrainer.getParseStream(this.trainData);
        this.testSamples = ShiftReduceParserTrainer.getParseStream(this.testData);
        this.rules = ShiftReduceParserTrainer.getHeadRules(params);
        this.createParserFactory(params);
        // Tagger factory is built before the chunker factory; each build echoes
        // its feature descriptor to stderr, so this order is observable.
        this.setTaggerFactory(this.createSequenceLabelerFactory(taggerParams));
        this.setChunkerFactory(this.createSequenceLabelerFactory(chunkerParams));
    }

    /**
     * Creates a trainer that will train the parser and the chunker only. No
     * tagger factory is created, so only
     * {@link #train(TrainingParameters, InputStream, TrainingParameters)} can be
     * used unless a tagger factory is set afterwards.
     *
     * @param params parser training parameters; also supply the language, the
     *        {@code TrainSet}/{@code TestSet} settings and the head rules file
     * @param chunkerParams parameters used to build the chunker factory
     * @throws IOException if the data, head rules or resources cannot be read
     */
    public ShiftReduceParserTrainer(TrainingParameters params, TrainingParameters chunkerParams) throws IOException {
        this.lang = Flags.getLanguage(params);
        this.trainData = (String)params.getSettings().get("TrainSet");
        this.testData = (String)params.getSettings().get("TestSet");
        this.trainSamples = ShiftReduceParserTrainer.getParseStream(this.trainData);
        this.testSamples = ShiftReduceParserTrainer.getParseStream(this.testData);
        this.rules = ShiftReduceParserTrainer.getHeadRules(params);
        this.createParserFactory(params);
        this.setChunkerFactory(this.createSequenceLabelerFactory(chunkerParams));
    }

    /**
     * Builds the parser factory from a dictionary derived from the training
     * samples plus the parse resources named in {@code params}, and stores it
     * via {@link #setParserFactory(ParserFactory)}.
     *
     * @param params the parser training parameters
     * @throws IOException if building the dictionary or loading resources fails
     */
    public void createParserFactory(TrainingParameters params) throws IOException {
        Dictionary autoDict = ShiftReduceParser.buildDictionary(this.trainSamples, this.rules, params);
        Map<String, Object> resources = LoadModelResources.loadParseResources(params);
        this.setParserFactory(ParserFactory.create(ParserFactory.class.getName(), autoDict, resources));
    }

    /**
     * Builds a sequence labeler factory (used for both tagger and chunker) from
     * the XML feature descriptor and the sequence resources described by
     * {@code params}. The generated feature descriptor is echoed to stderr.
     *
     * @param params the training parameters of the sequence labeler
     * @return the newly created factory
     * @throws IOException if the resources cannot be loaded
     */
    public SequenceLabelerFactory createSequenceLabelerFactory(TrainingParameters params) throws IOException {
        // May be null when no codec name has been set; the value is passed
        // through unchanged to instantiateSequenceCodec.
        String seqCodec = this.getSequenceCodec();
        SequenceLabelerCodec<String> codec = SequenceLabelerFactory.instantiateSequenceCodec(seqCodec);
        String featureDescription = XMLFeatureDescriptor.createXMLFeatureDescriptor(params);
        System.err.println(featureDescription);
        byte[] featureGeneratorBytes = featureDescription.getBytes(Charset.forName("UTF-8"));
        Map<String, Object> resources = LoadModelResources.loadSequenceResources(params);
        return SequenceLabelerFactory.create(SequenceLabelerFactory.class.getName(), featureGeneratorBytes, resources, codec);
    }

    /**
     * Trains parser, tagger and chunker, evaluates the resulting parser on the
     * test samples and prints the F-measure to stdout.
     *
     * <p>On an {@link IOException} the error is reported on stderr and the JVM
     * is terminated with exit status 1.
     *
     * @param params the parser training parameters
     * @param taggerParams the tagger training parameters
     * @param chunkerParams the chunker training parameters
     * @return the trained model
     * @throws IllegalStateException if the parser or tagger factory is not set
     */
    public final ParserModel train(TrainingParameters params, TrainingParameters taggerParams, TrainingParameters chunkerParams) {
        if (this.getParserFactory() == null) {
            throw new IllegalStateException("The ParserFactory must be instantiated!!");
        }
        if (this.getTaggerFactory() == null) {
            throw new IllegalStateException("The TaggerFactory must be instantiated!");
        }
        ParserModel trainedModel = null;
        try {
            trainedModel = ShiftReduceParser.train(this.lang, this.trainSamples, this.rules, params, this.parserFactory, taggerParams, this.taggerFactory, chunkerParams, this.chunkerFactory);
            this.evaluateAndReport(trainedModel);
        }
        catch (IOException e) {
            System.err.println("IO error while loading training and test sets!");
            e.printStackTrace();
            System.exit(1);
        }
        return trainedModel;
    }

    /**
     * Trains parser and chunker using an already trained tagger model,
     * evaluates the resulting parser on the test samples and prints the
     * F-measure to stdout.
     *
     * <p>On an {@link IOException}, either while loading the tagger model or
     * while training/evaluating, the error is reported on stderr and the JVM is
     * terminated with exit status 1.
     *
     * @param params the parser training parameters
     * @param taggerModel stream from which the tagger model is deserialized
     * @param chunkerParams the chunker training parameters
     * @return the trained model
     * @throws IllegalStateException if the parser factory is not set
     */
    public final ParserModel train(TrainingParameters params, InputStream taggerModel, TrainingParameters chunkerParams) {
        if (this.getParserFactory() == null) {
            throw new IllegalStateException("The ParserFactory must be instantiated!!");
        }
        SequenceLabelerModel posModel = null;
        try {
            posModel = new SequenceLabelerModel(taggerModel);
        }
        catch (IOException e) {
            // Training cannot proceed without the tagger model, so fail the
            // same way the training step does instead of continuing with null.
            System.err.println("IO error while loading the tagger model!");
            e.printStackTrace();
            System.exit(1);
        }
        ParserModel trainedModel = null;
        try {
            trainedModel = ShiftReduceParser.train(this.lang, this.trainSamples, this.rules, params, this.parserFactory, posModel, chunkerParams, this.chunkerFactory);
            this.evaluateAndReport(trainedModel);
        }
        catch (IOException e) {
            System.err.println("IO error while loading training and test sets!");
            e.printStackTrace();
            System.exit(1);
        }
        return trainedModel;
    }

    /** Evaluates a parser built from {@code trainedModel} on the test samples and prints its F-measure. */
    private void evaluateAndReport(ParserModel trainedModel) throws IOException {
        ShiftReduceParser parser = new ShiftReduceParser(trainedModel);
        ParserEvaluator parserEvaluator = new ParserEvaluator(parser, new ParserEvaluationMonitor[0]);
        parserEvaluator.evaluate(this.testSamples);
        System.out.println("Final Result: \n" + parserEvaluator.getFMeasure());
    }

    /**
     * Opens the given file as a stream of {@link Parse} samples.
     *
     * @param inputData path of the file holding the parses
     * @return a stream of parse samples read from that file
     * @throws IOException if the file cannot be read
     */
    public static ObjectStream<Parse> getParseStream(String inputData) throws IOException {
        ObjectStream<String> parseStream = IOUtils.readFileIntoMarkableStreamFactory(inputData);
        return new ParseSampleStream(parseStream);
    }

    /**
     * Loads the head rules for the language configured in {@code params} from
     * the head rules file configured in {@code params}. English ("en") and
     * Spanish ("es") are supported; the comparison is case-insensitive.
     *
     * @param params the parser training parameters
     * @return the deserialized head rules
     * @throws IOException if the head rules file cannot be read
     * @throws TerminateToolException if the language is not supported or the
     *         serializer does not produce a {@link HeadRules} instance
     */
    public static HeadRules getHeadRules(TrainingParameters params) throws IOException {
        String language = Flags.getLanguage(params);
        boolean english = "en".equalsIgnoreCase(language);
        boolean spanish = "es".equalsIgnoreCase(language);
        // Reject unsupported languages before touching the file system.
        if (!english && !spanish) {
            throw new TerminateToolException(-1, "HeadRules not supported for language " + language + "!!");
        }
        Object headRulesObject;
        // NOTE(review): assumes the serializers read the stream fully inside
        // create(), so closing it afterwards is safe - confirm.
        try (InputStream in = new FileInputStream(Flags.getHeadRulesFile(params))) {
            if (english) {
                headRulesObject = new PennTreebankHeadRules.PennTreebankHeadRulesSerializer().create(in);
            } else {
                headRulesObject = new AncoraHeadRules.AncoraHeadRulesSerializer().create(in);
            }
        }
        if (headRulesObject instanceof HeadRules) {
            return (HeadRules)headRulesObject;
        }
        throw new TerminateToolException(-1, "HeadRules Artifact Serializer must create an object of type HeadRules!");
    }

    /** @return the tagger factory, or {@code null} if none has been set */
    public final SequenceLabelerFactory getTaggerFactory() {
        return this.taggerFactory;
    }

    /**
     * Sets the tagger factory.
     *
     * @param taggerFactory the factory to use for the tagger
     * @return the factory that was just set
     */
    public final SequenceLabelerFactory setTaggerFactory(SequenceLabelerFactory taggerFactory) {
        this.taggerFactory = taggerFactory;
        return this.taggerFactory;
    }

    /** @return the chunker factory, or {@code null} if none has been set */
    public final SequenceLabelerFactory getChunkerFactory() {
        return this.chunkerFactory;
    }

    /**
     * Sets the chunker factory.
     *
     * @param chunkerFactory the factory to use for the chunker
     * @return the factory that was just set
     */
    public final SequenceLabelerFactory setChunkerFactory(SequenceLabelerFactory chunkerFactory) {
        this.chunkerFactory = chunkerFactory;
        return this.chunkerFactory;
    }

    /** @return the parser factory, or {@code null} if none has been set */
    public final ParserFactory getParserFactory() {
        return this.parserFactory;
    }

    /**
     * Sets the parser factory.
     *
     * @param parserFactory the factory to use for the parser
     * @return the factory that was just set
     */
    public final ParserFactory setParserFactory(ParserFactory parserFactory) {
        this.parserFactory = parserFactory;
        return parserFactory;
    }

    /**
     * Maps the configured short codec name to a codec class name.
     *
     * @return the class name of {@link BioCodec} for "BIO", of
     *         {@link BilouCodec} for "BILOU", and {@code null} otherwise
     */
    public final String getSequenceCodec() {
        if ("BIO".equals(this.sequenceCodec)) {
            return BioCodec.class.getName();
        }
        if ("BILOU".equals(this.sequenceCodec)) {
            return BilouCodec.class.getName();
        }
        return null;
    }

    /**
     * Sets the short codec name consulted by {@link #getSequenceCodec()}.
     *
     * @param aSeqCodec the codec name, e.g. "BIO" or "BILOU"
     */
    public final void setSequenceCodec(String aSeqCodec) {
        this.sequenceCodec = aSeqCodec;
    }
}

