/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.share.upenn;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.CharSequenceArray2TokenSequence;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.iterator.ArrayDataAndTargetIterator;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.pipe.iterator.LineIterator;
import cc.mallet.pipe.iterator.PipeExtendedIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import cc.mallet.types.TokenSequence;
import cc.mallet.util.CharSequenceLexer;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
 * Static helpers and a command-line entry point for training, testing and applying a
 * maximum-entropy classifier built with {@link MaxEntTrainer}.
 *
 * <p>The class is a pure utility holder and cannot be instantiated. Programmatic callers use
 * the {@code train}, {@code test}, {@code classify} and {@code load} methods; {@link #main}
 * wires the same operations to command-line options.
 */
public class MaxEntShell {
    private static final Logger logger = MalletLogger.getLogger(MaxEntShell.class.getName());
    private static final CommandOption.Double gaussianVarianceOption = new CommandOption.Double(MaxEntShell.class, "gaussian-variance", "decimal", true, 1.0, "The gaussian prior variance used for training.", null);
    private static final CommandOption.File trainOption = new CommandOption.File(MaxEntShell.class, "train", "FILENAME", true, null, "Training datafile", null);
    private static final CommandOption.File testOption = new CommandOption.File(MaxEntShell.class, "test", "filename", true, null, "Test datafile", null);
    private static final CommandOption.File classifyOption = new CommandOption.File(MaxEntShell.class, "classify", "filename", true, null, "Datafile to classify", null);
    private static final CommandOption.File modelOption = new CommandOption.File(MaxEntShell.class, "model", "filename", true, null, "Model file", null);
    private static final CommandOption.String encodingOption = new CommandOption.String(MaxEntShell.class, "encoding", "character-encoding-name", true, null, "Input character encoding", null);
    private static final CommandOption.Boolean internalTestOption = new CommandOption.Boolean(MaxEntShell.class, "internal-test", "true|false", true, false, "Run internal tests", null);
    private static final CommandOption.List commandOptions = new CommandOption.List("Training, testing and running a generic tagger.", new CommandOption[]{gaussianVarianceOption, trainOption, testOption, modelOption, classifyOption, encodingOption, internalTestOption});

    /** Tiny built-in data set used by {@link #internalTest()}. */
    private static final String[][] internalData = new String[][]{{"a", "b"}, {"b", "c"}, {"a", "c"}};
    private static final String[] internalTargets = new String[]{"yes", "no", "no"};
    private static final String[] internalInstance = new String[]{"a", "b", "c"};

    /**
     * Labeled input line: the first whitespace-delimited token is capture group 1 and the
     * remainder of the line is capture group 2.
     */
    private static final Pattern LABELED_LINE = Pattern.compile("^\\s*(\\S+)\\s*(.*)\\s*$");

    /** Unlabeled input line: everything after leading whitespace is capture group 1. */
    private static final Pattern UNLABELED_LINE = Pattern.compile("^\\s*(.*)\\s*$");

    private MaxEntShell() {
    }

    /**
     * Trains a classifier from in-memory arrays.
     *
     * @param features one array of feature strings per training instance
     * @param labels   the target label of each instance, parallel to {@code features}
     * @param var      Gaussian prior variance handed to {@link MaxEntTrainer}
     * @param save     file to serialize the trained classifier to, or {@code null} to skip saving
     * @return the trained classifier
     * @throws IOException if the classifier cannot be written to {@code save}
     */
    public static Classifier train(String[][] features, String[] labels, double var, File save) throws IOException {
        return MaxEntShell.train(new PipeExtendedIterator(new ArrayDataAndTargetIterator((Object[])features, labels), new CharSequenceArray2TokenSequence()), var, save);
    }

    /**
     * Trains a classifier from an iterator of instances.
     *
     * <p>Each instance is passed through a pipe that maps its target to a label and its token
     * sequence to a feature vector. Sizes and training accuracy are logged, and the feature
     * alphabet is frozen once training has finished.
     *
     * @param data training instances
     * @param var  Gaussian prior variance handed to {@link MaxEntTrainer}
     * @param save file to serialize the trained classifier to, or {@code null} to skip saving
     * @return the trained classifier
     * @throws IOException if the classifier cannot be written to {@code save}
     */
    public static Classifier train(Iterator<Instance> data, double var, File save) throws IOException {
        Alphabet features = new Alphabet();
        LabelAlphabet labels = new LabelAlphabet();
        SerialPipes instancePipe = new SerialPipes(new Pipe[]{new Target2Label(labels), new TokenSequence2FeatureSequence(features), new FeatureSequence2FeatureVector()});
        InstanceList trainingList = new InstanceList(instancePipe);
        trainingList.addThruPipe(data);
        logger.info("# features = " + features.size());
        logger.info("# labels = " + labels.size());
        logger.info("# training instances = " + trainingList.size());
        MaxEntTrainer trainer = new MaxEntTrainer(var);
        Classifier classifier = trainer.train(trainingList);
        logger.info("The training accuracy is " + classifier.getAccuracy(trainingList));
        features.stopGrowth();
        if (save != null) {
            FileOutputStream fileOut = new FileOutputStream(save);
            try {
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
                objectOut.writeObject(classifier);
                // Flushes buffered object data and closes fileOut on the success path.
                objectOut.close();
            } finally {
                // Releases the file handle if anything above threw; harmless if already closed.
                fileOut.close();
            }
        }
        return classifier;
    }

    /**
     * Evaluates a classifier on in-memory arrays.
     *
     * @param classifier the classifier to evaluate
     * @param features   one array of feature strings per test instance
     * @param labels     the true label of each instance, parallel to {@code features}
     * @return the accuracy reported by {@link Classifier#getAccuracy}
     */
    public static double test(Classifier classifier, String[][] features, String[] labels) {
        return MaxEntShell.test(classifier, new PipeExtendedIterator(new ArrayDataAndTargetIterator((Object[])features, labels), new CharSequenceArray2TokenSequence()));
    }

    /**
     * Evaluates a classifier on an iterator of labeled instances.
     *
     * @param classifier the classifier to evaluate; its own instance pipe processes {@code data}
     * @param data       labeled test instances
     * @return the accuracy reported by {@link Classifier#getAccuracy}
     */
    public static double test(Classifier classifier, Iterator<Instance> data) {
        InstanceList testList = new InstanceList(classifier.getInstancePipe());
        testList.addThruPipe(data);
        logger.info("# test instances = " + testList.size());
        return classifier.getAccuracy(testList);
    }

    /**
     * Classifies a single instance given as an array of feature strings.
     *
     * @param classifier the classifier to apply
     * @param features   the feature strings of the instance
     * @return the classification produced for the instance
     */
    public static Classification classify(Classifier classifier, String[] features) {
        return classifier.classify(new Instance(new TokenSequence(features), null, null, null));
    }

    /**
     * Classifies several instances given as arrays of feature strings.
     *
     * @param classifier the classifier to apply
     * @param features   one array of feature strings per instance
     * @return one classification per instance
     */
    public static Classification[] classify(Classifier classifier, String[][] features) {
        return MaxEntShell.classify(classifier, new PipeExtendedIterator(new ArrayIterator((Object[])features), new CharSequenceArray2TokenSequence()));
    }

    /**
     * Classifies every instance produced by an iterator.
     *
     * @param classifier the classifier to apply; its own instance pipe processes {@code data}
     * @param data       unlabeled instances
     * @return one classification per instance
     */
    public static Classification[] classify(Classifier classifier, Iterator<Instance> data) {
        InstanceList unlabeledList = new InstanceList(classifier.getInstancePipe());
        unlabeledList.addThruPipe(data);
        logger.info("# unlabeled instances = " + unlabeledList.size());
        ArrayList<Classification> classifications = classifier.classify(unlabeledList);
        return classifications.toArray(new Classification[0]);
    }

    /**
     * Loads a classifier previously written with Java serialization (for example by
     * {@link #train(Iterator, double, File)}).
     *
     * <p>NOTE(review): this uses native Java deserialization, so only load model files from
     * trusted sources.
     *
     * @param modelFile the serialized classifier
     * @return the deserialized classifier
     * @throws IOException            if the file cannot be read
     * @throws ClassNotFoundException if a class in the serialized stream cannot be resolved
     */
    public static Classifier load(File modelFile) throws IOException, ClassNotFoundException {
        FileInputStream fileIn = new FileInputStream(modelFile);
        try {
            ObjectInputStream objectIn = new ObjectInputStream(fileIn);
            return (Classifier)objectIn.readObject();
        } finally {
            // Closing the underlying stream releases the handle on both success and failure.
            fileIn.close();
        }
    }

    /**
     * Trains on the built-in data set, prints the training accuracy, then prints the label
     * scores for the built-in sample instance.
     */
    private static void internalTest() throws IOException {
        Classifier classifier = MaxEntShell.train(internalData, internalTargets, 1.0, null);
        System.out.println("Training accuracy = " + MaxEntShell.test(classifier, internalData, internalTargets));
        MaxEntShell.printLabeling(MaxEntShell.classify(classifier, internalInstance).getLabeling());
    }

    /** Prints every label followed by its value on one line of standard output. */
    private static void printLabeling(Labeling labeling) {
        LabelAlphabet labels = labeling.getLabelAlphabet();
        for (int c = 0; c < labels.size(); ++c) {
            System.out.print(labels.lookupObject(c) + " " + labeling.value(c) + " ");
        }
        System.out.println();
    }

    /**
     * Opens a character reader on {@code file}.
     *
     * @param file     the file to read
     * @param encoding the charset name to decode with, or {@code null} to use {@link FileReader}
     * @return an open reader that the caller must close
     * @throws IOException if the file cannot be opened or the encoding is not supported
     */
    private static InputStreamReader getReader(File file, String encoding) throws IOException {
        if (encoding == null) {
            return new FileReader(file);
        }
        FileInputStream in = new FileInputStream(file);
        try {
            return new InputStreamReader(in, encoding);
        } catch (IOException e) {
            // An unsupported encoding name must not leave the file handle open.
            in.close();
            throw e;
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>After option parsing: runs the internal test if requested; obtains a classifier either
     * by training on the {@code train} file (saving to the {@code model} file if one is given) or,
     * when no training file is given, by loading the {@code model} file; then, if a classifier is
     * available, prints the accuracy on the {@code test} file and/or prints the label scores for
     * each line of the {@code classify} file.
     *
     * @param args command-line arguments processed by the option list
     * @throws Exception if option handling, I/O, or model loading fails
     */
    public static void main(String[] args) throws Exception {
        Classifier classifier = null;
        CharSequence2TokenSequence preprocess = new CharSequence2TokenSequence(new CharSequenceLexer(CharSequenceLexer.LEX_NONWHITESPACE_TOGETHER));
        commandOptions.process(args);
        if (MaxEntShell.internalTestOption.value) {
            MaxEntShell.internalTest();
        }
        if (MaxEntShell.trainOption.value != null) {
            Reader trainingData = MaxEntShell.getReader(MaxEntShell.trainOption.value, MaxEntShell.encodingOption.value);
            try {
                // Regex group 2 (rest of line) and group 1 (first token) are passed to
                // LineIterator; presumably data group and target group — see LineIterator docs.
                classifier = MaxEntShell.train(new PipeExtendedIterator(new LineIterator(trainingData, LABELED_LINE, 2, 1, -1), preprocess), MaxEntShell.gaussianVarianceOption.value, MaxEntShell.modelOption.value);
            } finally {
                trainingData.close();
            }
        } else if (MaxEntShell.modelOption.value != null) {
            classifier = MaxEntShell.load(MaxEntShell.modelOption.value);
        }
        if (classifier == null) {
            return;
        }
        if (MaxEntShell.testOption.value != null) {
            Reader testData = MaxEntShell.getReader(MaxEntShell.testOption.value, MaxEntShell.encodingOption.value);
            try {
                System.out.println("The testing accuracy is " + MaxEntShell.test(classifier, new PipeExtendedIterator(new LineIterator(testData, LABELED_LINE, 2, 1, -1), preprocess)));
            } finally {
                testData.close();
            }
        }
        if (MaxEntShell.classifyOption.value != null) {
            // Unlabeled lines carry no target, so target processing is switched off first.
            classifier.getInstancePipe().setTargetProcessing(false);
            Reader unlabeledData = MaxEntShell.getReader(MaxEntShell.classifyOption.value, MaxEntShell.encodingOption.value);
            Classification[] results;
            try {
                results = MaxEntShell.classify(classifier, new PipeExtendedIterator(new LineIterator(unlabeledData, UNLABELED_LINE, 1, -1, -1), preprocess));
            } finally {
                unlabeledData.close();
            }
            for (int i = 0; i < results.length; ++i) {
                MaxEntShell.printLabeling(results[i].getLabeling());
            }
        }
    }
}

