/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.tui;

import cc.mallet.fst.CRF;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.MultiSegmentationEvaluator;
import cc.mallet.fst.NoopTransducerTrainer;
import cc.mallet.fst.SimpleTagger;
import cc.mallet.fst.TokenAccuracyEvaluator;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.CRFTrainerByGE;
import cc.mallet.fst.semi_supervised.FSTConstraintUtil;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.fst.semi_supervised.constraints.OneLabelKLGEConstraints;
import cc.mallet.fst.semi_supervised.constraints.OneLabelL2RangeGEConstraints;
import cc.mallet.fst.semi_supervised.pr.CRFTrainerByPR;
import cc.mallet.fst.semi_supervised.pr.constraints.OneLabelL2IndPRConstraints;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
 * Command-line driver that trains, evaluates, or applies a CRF sequence tagger
 * whose training signal comes from feature/label expectation constraints
 * (Generalized Expectation or Posterior Regularization) rather than from fully
 * labeled data.
 *
 * <p>Positional arguments after the options are, depending on the mode:
 * <ul>
 *   <li>{@code --train true --test ...}: training file, test file, constraints file</li>
 *   <li>{@code --train true}: training file, constraints file</li>
 *   <li>otherwise: the file to evaluate or label</li>
 * </ul>
 */
public class SimpleTaggerWithConstraints {
    private static final Logger logger = MalletLogger.getLogger(SimpleTaggerWithConstraints.class.getName());

    /** Instances in the data files are separated by lines matching this (whitespace-only) pattern. */
    private static final Pattern BLANK_LINE = Pattern.compile("^\\s*$");

    private static final CommandOption.Double gaussianVarianceOption = new CommandOption.Double(SimpleTaggerWithConstraints.class, "gaussian-variance", "DECIMAL", true, 10.0, "The gaussian prior variance used for training.", null);
    private static final CommandOption.Double qGaussianVarianceOption = new CommandOption.Double(SimpleTaggerWithConstraints.class, "q-gaussian-variance", "DECIMAL", true, 10.0, "The gaussian prior variance used in the E-step for PR training.", null);
    private static final CommandOption.Boolean trainOption = new CommandOption.Boolean(SimpleTaggerWithConstraints.class, "train", "true|false", true, false, "Whether to train", null);
    private static final CommandOption.String testOption = new CommandOption.String(SimpleTaggerWithConstraints.class, "test", "lab or seg=start-1.continue-1,...,start-n.continue-n", true, null, "Test measuring labeling or segmentation (start-i, continue-i) accuracy", null);
    private static final CommandOption.File modelOption = new CommandOption.File(SimpleTaggerWithConstraints.class, "model-file", "FILENAME", true, null, "The filename for reading (train/run) or saving (train) the model.", null);
    private static final CommandOption.Double trainingFractionOption = new CommandOption.Double(SimpleTaggerWithConstraints.class, "training-proportion", "DECIMAL", true, 0.5, "Fraction of data to use for training in a random split.", null);
    private static final CommandOption.Integer randomSeedOption = new CommandOption.Integer(SimpleTaggerWithConstraints.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null);
    private static final CommandOption.IntegerArray ordersOption = new CommandOption.IntegerArray(SimpleTaggerWithConstraints.class, "orders", "COMMA-SEP-DECIMALS", true, new int[]{1}, "List of label Markov orders (main and backoff) ", null);
    private static final CommandOption.String forbiddenOption = new CommandOption.String(SimpleTaggerWithConstraints.class, "forbidden", "REGEXP", true, "\\s", "label1,label2 transition forbidden if it matches this", null);
    private static final CommandOption.String allowedOption = new CommandOption.String(SimpleTaggerWithConstraints.class, "allowed", "REGEXP", true, ".*", "label1,label2 transition allowed only if it matches this", null);
    private static final CommandOption.String defaultOption = new CommandOption.String(SimpleTaggerWithConstraints.class, "default-label", "STRING", true, "O", "Label for initial context and uninteresting tokens", null);
    private static final CommandOption.String penaltyOption = new CommandOption.String(SimpleTaggerWithConstraints.class, "penalty", "kl|l2", true, "l2", "penalty function for constraint violation.", null);
    private static final CommandOption.String learningOption = new CommandOption.String(SimpleTaggerWithConstraints.class, "learning", "ge|pr", true, "ge", "Learning method to use.", null);
    private static final CommandOption.Integer iterationsOption = new CommandOption.Integer(SimpleTaggerWithConstraints.class, "iterations", "INTEGER", true, 500, "Number of training iterations", null);
    private static final CommandOption.Boolean viterbiOutputOption = new CommandOption.Boolean(SimpleTaggerWithConstraints.class, "viterbi-output", "true|false", true, false, "Print Viterbi periodically during training", null);
    private static final CommandOption.Boolean continueTrainingOption = new CommandOption.Boolean(SimpleTaggerWithConstraints.class, "continue-training", "true|false", false, false, "Continue training from model specified by --model-file", null);
    private static final CommandOption.Integer nBestOption = new CommandOption.Integer(SimpleTaggerWithConstraints.class, "n-best", "INTEGER", true, 1, "How many answers to output", null);
    private static final CommandOption.Integer cacheSizeOption = new CommandOption.Integer(SimpleTaggerWithConstraints.class, "cache-size", "INTEGER", true, 100000, "How much state information to memoize in n-best decoding", null);
    private static final CommandOption.Boolean includeInputOption = new CommandOption.Boolean(SimpleTaggerWithConstraints.class, "include-input", "true|false", true, false, "Whether to include the input features when printing decoding output", null);
    private static final CommandOption.Integer numThreads = new CommandOption.Integer(SimpleTaggerWithConstraints.class, "threads", "INTEGER", true, 1, "Number of threads to use for CRF training.", null);
    private static final CommandOption.Integer numResets = new CommandOption.Integer(SimpleTaggerWithConstraints.class, "resets", "INTEGER", true, 4, "Number of L-BFGS resets to use.", null);
    private static final CommandOption.List commandOptions = new CommandOption.List("Training, testing and running a generic tagger.", new CommandOption[]{gaussianVarianceOption, qGaussianVarianceOption, trainOption, iterationsOption, testOption, trainingFractionOption, modelOption, randomSeedOption, ordersOption, forbiddenOption, allowedOption, defaultOption, viterbiOutputOption, penaltyOption, learningOption, continueTrainingOption, nBestOption, cacheSizeOption, includeInputOption, numThreads, numResets});

    /** Not instantiable: this class only exposes static entry points. */
    private SimpleTaggerWithConstraints() {
    }

    /**
     * Trains {@code crf} with a {@link CRFTrainerByGE} using the given constraints.
     *
     * @param training    instances to train on
     * @param testing     held-out instances; only used here to log their count, may be null
     * @param constraints GE constraints handed to the trainer
     * @param crf         model to train (modified in place)
     * @param eval        evaluator registered with the trainer, or null for none
     * @param iterations  iteration count passed to the trainer
     * @param var         Gaussian prior variance
     * @param resets      number of resets configured on the trainer
     * @return the same {@code crf} instance that was passed in
     */
    public static CRF trainGE(InstanceList training, InstanceList testing, ArrayList<GEConstraint> constraints, CRF crf, TransducerEvaluator eval, int iterations, double var, int resets) {
        logger.info("Training on " + training.size() + " instances");
        if (testing != null) {
            logger.info("Testing on " + testing.size() + " instances");
        }
        assert (SimpleTaggerWithConstraints.numThreads.value > 0);
        CRFTrainerByGE trainer = new CRFTrainerByGE(crf, constraints, SimpleTaggerWithConstraints.numThreads.value);
        if (eval != null) {
            trainer.addEvaluator(eval);
        }
        trainer.setGaussianPriorVariance(var);
        trainer.setNumResets(resets);
        trainer.train(training, iterations);
        return crf;
    }

    /**
     * Trains {@code crf} with a {@link CRFTrainerByPR} using the given constraints.
     *
     * @param training    instances to train on
     * @param testing     held-out instances; only used here to log their count, may be null
     * @param constraints PR constraints handed to the trainer
     * @param crf         model to train (modified in place)
     * @param eval        evaluator registered with the trainer, or null for none
     * @param iterations  iteration count; passed as both iteration arguments of the trainer's train call
     * @param var         Gaussian prior variance set via {@code setPGaussianPriorVariance}
     * @return the same {@code crf} instance that was passed in
     */
    public static CRF trainPR(InstanceList training, InstanceList testing, ArrayList<PRConstraint> constraints, CRF crf, TransducerEvaluator eval, int iterations, double var) {
        logger.info("Training on " + training.size() + " instances");
        if (testing != null) {
            logger.info("Testing on " + testing.size() + " instances");
        }
        assert (SimpleTaggerWithConstraints.numThreads.value > 0);
        CRFTrainerByPR trainer = new CRFTrainerByPR(crf, constraints, SimpleTaggerWithConstraints.numThreads.value);
        // main() passes a null evaluator when --test is not given; mirror trainGE
        // and skip registration instead of handing null to the trainer.
        if (eval != null) {
            trainer.addEvaluator(eval);
        }
        trainer.setPGaussianPriorVariance(var);
        trainer.train(training, iterations, iterations);
        return crf;
    }

    /**
     * Builds an untrained CRF over the pipe of {@code training}.
     *
     * <p>States are added with {@code addOrderNStates}; every state's initial
     * weight is then set to negative infinity except the start state returned by
     * that call, which gets 0.0. Weight dimensions are set densely.
     *
     * @param training     instances whose pipe the CRF is built on
     * @param orders       label Markov orders passed to {@code addOrderNStates}
     * @param defaultLabel default label passed to {@code addOrderNStates}
     * @param forbidden    regular expression compiled into the forbidden-transition pattern
     * @param allowed      regular expression compiled into the allowed-transition pattern
     * @param connected    passed through to {@code addOrderNStates}
     * @return the newly constructed CRF
     */
    public static CRF getCRF(InstanceList training, int[] orders, String defaultLabel, String forbidden, String allowed, boolean connected) {
        Pattern forbiddenPat = Pattern.compile(forbidden);
        Pattern allowedPat = Pattern.compile(allowed);
        CRF crf = new CRF(training.getPipe(), null);
        String startName = crf.addOrderNStates(training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected);
        for (int i = 0; i < crf.numStates(); ++i) {
            crf.getState(i).setInitialWeight(Double.NEGATIVE_INFINITY);
        }
        crf.getState(startName).setInitialWeight(0.0);
        crf.setWeightsDimensionDensely();
        return crf;
    }

    /**
     * Runs {@code eval} on {@code testing} under the description "Testing".
     *
     * @param tt      trainer wrapping the model to evaluate
     * @param eval    evaluator to run
     * @param testing instances to evaluate on
     */
    public static void test(TransducerTrainer tt, TransducerEvaluator eval, InstanceList testing) {
        eval.evaluateInstanceList(tt, testing, "Testing");
    }

    /**
     * Labels {@code input} with {@code model}.
     *
     * @param model transducer to apply
     * @param input input sequence
     * @param k     number of answers requested; exactly 1 uses {@code model.transduce},
     *              any other value uses a {@link MaxLatticeDefault} sized by the
     *              {@code --cache-size} option
     * @return the output sequences produced
     */
    public static Sequence[] apply(Transducer model, Sequence input, int k) {
        if (k == 1) {
            return new Sequence[]{model.transduce(input)};
        }
        MaxLatticeDefault lattice = new MaxLatticeDefault(model, input, null, cacheSizeOption.value());
        return lattice.bestOutputSequences(k).toArray(new Sequence[0]);
    }

    /**
     * Command-line entry point: parses options, opens the positional data files,
     * runs the requested mode, and reports elapsed wall-clock seconds on stderr.
     *
     * @param args options followed by the data file paths (see class comment)
     * @throws IllegalArgumentException if required files or options are missing or malformed
     * @throws Exception                on I/O or model (de)serialization failure
     */
    public static void main(String[] args) throws Exception {
        long startTime = System.currentTimeMillis();
        int restArgs = commandOptions.processOptions(args);
        if (restArgs == args.length) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Missing data file(s)");
        }
        boolean training = SimpleTaggerWithConstraints.trainOption.value;
        boolean evaluating = SimpleTaggerWithConstraints.testOption.value != null;
        // Fail with a usage message instead of an ArrayIndexOutOfBoundsException
        // when fewer positional files are given than the selected mode reads.
        int filesNeeded = training ? (evaluating ? 3 : 2) : 1;
        int filesGiven = args.length - restArgs;
        if (filesGiven < filesNeeded) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Missing data file(s): expected " + filesNeeded + " but found " + filesGiven);
        }

        FileReader trainingFile = null;
        FileReader testFile = null;
        FileReader constraintsFile = null;
        try {
            if (training) {
                trainingFile = new FileReader(new File(args[restArgs]));
                if (evaluating) {
                    testFile = new FileReader(new File(args[restArgs + 1]));
                    constraintsFile = new FileReader(new File(args[restArgs + 2]));
                } else {
                    constraintsFile = new FileReader(new File(args[restArgs + 1]));
                }
            } else {
                testFile = new FileReader(new File(args[restArgs]));
            }
            run(trainingFile, testFile, constraintsFile);
        } finally {
            closeQuietly(trainingFile);
            closeQuietly(testFile);
            closeQuietly(constraintsFile);
        }
        long time = (System.currentTimeMillis() - startTime) / 1000L;
        System.err.println("took " + time + " seconds");
    }

    /** Executes the train / evaluate / label workflow on already-opened input files (unused ones are null). */
    private static void run(FileReader trainingFile, FileReader testFile, FileReader constraintsFile) throws Exception {
        Pipe p;
        CRF crf = null;
        if (SimpleTaggerWithConstraints.continueTrainingOption.value || !SimpleTaggerWithConstraints.trainOption.value) {
            if (SimpleTaggerWithConstraints.modelOption.value == null) {
                commandOptions.printUsage(true);
                throw new IllegalArgumentException("Missing model file option");
            }
            crf = loadModel(SimpleTaggerWithConstraints.modelOption.value);
            p = crf.getInputPipe();
        } else {
            p = new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence();
            p.getTargetAlphabet().lookupIndex(SimpleTaggerWithConstraints.defaultOption.value);
        }

        InstanceList trainingData = null;
        InstanceList testData = null;
        if (SimpleTaggerWithConstraints.trainOption.value) {
            p.setTargetProcessing(true);
            trainingData = readInstances(p, trainingFile);
            logger.info("Number of features in training data: " + p.getDataAlphabet().size());
            if (SimpleTaggerWithConstraints.testOption.value != null) {
                if (testFile != null) {
                    testData = readInstances(p, testFile);
                } else {
                    // NOTE(review): main() always opens a test file when --test is given in
                    // training mode, so this random-split fallback looks unreachable from main.
                    Random r = new Random(SimpleTaggerWithConstraints.randomSeedOption.value);
                    double fraction = SimpleTaggerWithConstraints.trainingFractionOption.value;
                    InstanceList[] split = trainingData.split(r, new double[]{fraction, 1.0 - fraction});
                    trainingData = split[0];
                    testData = split[1];
                }
            }
        } else {
            // Targets are only read from the file when an evaluation was requested.
            p.setTargetProcessing(SimpleTaggerWithConstraints.testOption.value != null);
            testData = readInstances(p, testFile);
        }
        logger.info("Number of predicates: " + p.getDataAlphabet().size());

        TransducerEvaluator eval = buildEvaluator(trainingData, testData);
        if (p.isTargetProcessing()) {
            logLabels(p.getTargetAlphabet());
        }

        if (SimpleTaggerWithConstraints.trainOption.value) {
            if (crf == null) {
                crf = SimpleTaggerWithConstraints.getCRF(trainingData, SimpleTaggerWithConstraints.ordersOption.value, SimpleTaggerWithConstraints.defaultOption.value, SimpleTaggerWithConstraints.forbiddenOption.value, SimpleTaggerWithConstraints.allowedOption.value, true);
            }
            crf = trainWithConstraints(crf, trainingData, testData, constraintsFile, eval);
            if (SimpleTaggerWithConstraints.modelOption.value != null) {
                saveModel(crf, SimpleTaggerWithConstraints.modelOption.value);
            }
        } else if (eval != null) {
            SimpleTaggerWithConstraints.test(new NoopTransducerTrainer(crf), eval, testData);
        } else {
            printDecodings(crf, testData);
        }
    }

    /** Reads blank-line-separated instances from {@code source} through {@code pipe}. */
    private static InstanceList readInstances(Pipe pipe, FileReader source) {
        InstanceList data = new InstanceList(pipe);
        data.addThruPipe(new LineGroupIterator(source, BLANK_LINE, true));
        return data;
    }

    /** Builds the evaluator selected by {@code --test}, or returns null when the option is unset. */
    private static TransducerEvaluator buildEvaluator(InstanceList trainingData, InstanceList testData) {
        String spec = SimpleTaggerWithConstraints.testOption.value;
        if (spec == null) {
            return null;
        }
        InstanceList[] lists = new InstanceList[]{trainingData, testData};
        String[] names = new String[]{"Training", "Testing"};
        if (spec.startsWith("lab")) {
            return new TokenAccuracyEvaluator(lists, names);
        }
        if (!spec.startsWith("seg=")) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Invalid test option: " + spec);
        }
        // Everything after "seg=" is a comma-separated list of start.continue label pairs.
        String[] pairs = spec.substring(4).split(",");
        if (pairs.length < 1) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Missing segment start/continue labels: " + spec);
        }
        String[] startTags = new String[pairs.length];
        String[] continueTags = new String[pairs.length];
        for (int i = 0; i < pairs.length; ++i) {
            String[] pair = pairs[i].split("\\.");
            if (pair.length != 2) {
                commandOptions.printUsage(true);
                throw new IllegalArgumentException("Incorrectly-specified segment start and end labels: " + pairs[i]);
            }
            startTags[i] = pair[0];
            continueTags[i] = pair[1];
        }
        return new MultiSegmentationEvaluator(lists, names, startTags, continueTags);
    }

    /** Logs every entry of the target alphabet on a single "Labels:" line. */
    private static void logLabels(Alphabet targets) {
        StringBuilder buf = new StringBuilder("Labels:");
        for (int i = 0; i < targets.size(); ++i) {
            buf.append(" ").append(targets.lookupObject(i).toString());
        }
        logger.info(buf.toString());
    }

    /** Loads the constraints file and trains {@code crf} with the method chosen by {@code --learning}. */
    private static CRF trainWithConstraints(CRF crf, InstanceList trainingData, InstanceList testData, FileReader constraintsFile, TransducerEvaluator eval) {
        HashMap<Integer, double[][]> constraints = FSTConstraintUtil.loadGEConstraints(constraintsFile, trainingData);
        String learning = SimpleTaggerWithConstraints.learningOption.value;
        if (learning.equalsIgnoreCase("ge")) {
            ArrayList<GEConstraint> geList = buildGEConstraints(constraints);
            return SimpleTaggerWithConstraints.trainGE(trainingData, testData, geList, crf, eval, SimpleTaggerWithConstraints.iterationsOption.value, SimpleTaggerWithConstraints.gaussianVarianceOption.value, SimpleTaggerWithConstraints.numResets.value);
        }
        if (learning.equalsIgnoreCase("pr")) {
            ArrayList<PRConstraint> prList = buildPRConstraints(constraints);
            return SimpleTaggerWithConstraints.trainPR(trainingData, testData, prList, crf, eval, SimpleTaggerWithConstraints.iterationsOption.value, SimpleTaggerWithConstraints.gaussianVarianceOption.value);
        }
        throw new RuntimeException("Unknown learning algorithm " + learning);
    }

    /**
     * Converts loaded constraints (feature index to per-label {lower, upper} pairs)
     * into the GE constraint object selected by {@code --penalty}.
     */
    private static ArrayList<GEConstraint> buildGEConstraints(HashMap<Integer, double[][]> constraints) {
        ArrayList<GEConstraint> constraintsList = new ArrayList<GEConstraint>();
        String penalty = SimpleTaggerWithConstraints.penaltyOption.value;
        if (penalty.equalsIgnoreCase("kl")) {
            OneLabelKLGEConstraints kl = new OneLabelKLGEConstraints();
            for (int fi : constraints.keySet()) {
                double[][] dist = constraints.get(fi);
                double[] prob = new double[dist.length];
                double sum = 0.0;
                for (int li = 0; li < dist.length; ++li) {
                    double target = dist[li][0];
                    // KL needs a single target per label, so both bounds must coincide.
                    if (!Maths.almostEquals(target, dist[li][1])) {
                        throw new RuntimeException("A KL divergence penalty cannot be used with target ranges!");
                    }
                    prob[li] = Double.isInfinite(target) ? 0.0 : target;
                    sum += prob[li];
                }
                if (!Maths.almostEquals(sum, 1.0)) {
                    throw new RuntimeException("Targets must sum to 1 when using a KL divergence penalty!");
                }
                kl.addConstraint(fi, prob, 1.0);
            }
            constraintsList.add(kl);
        } else if (penalty.equalsIgnoreCase("l2")) {
            OneLabelL2RangeGEConstraints l2 = new OneLabelL2RangeGEConstraints();
            for (int fi : constraints.keySet()) {
                double[][] dist = constraints.get(fi);
                for (int li = 0; li < dist.length; ++li) {
                    // Labels whose lower bound is infinite are skipped.
                    if (!Double.isInfinite(dist[li][0])) {
                        l2.addConstraint(fi, li, dist[li][0], dist[li][1], 1.0);
                    }
                }
            }
            constraintsList.add(l2);
        } else {
            throw new RuntimeException("Unknown penalty " + penalty);
        }
        return constraintsList;
    }

    /**
     * Converts loaded constraints into PR constraints. Only the "l2" penalty is
     * accepted, and each finite target must have equal lower and upper bounds.
     */
    private static ArrayList<PRConstraint> buildPRConstraints(HashMap<Integer, double[][]> constraints) {
        String penalty = SimpleTaggerWithConstraints.penaltyOption.value;
        if (!penalty.equalsIgnoreCase("l2")) {
            if (penalty.equalsIgnoreCase("kl")) {
                throw new RuntimeException("KL divergence not supported for PR.");
            }
            throw new RuntimeException("Unknown penalty " + penalty);
        }
        OneLabelL2IndPRConstraints prConstraints = new OneLabelL2IndPRConstraints(true);
        for (int fi : constraints.keySet()) {
            double[][] dist = constraints.get(fi);
            for (int li = 0; li < dist.length; ++li) {
                if (Double.isInfinite(dist[li][0])) {
                    continue;
                }
                if (!Maths.almostEquals(dist[li][0], dist[li][1])) {
                    throw new RuntimeException("Support for range constraints in PR in development. " + penalty);
                }
                prConstraints.addConstraint(fi, li, dist[li][0], SimpleTaggerWithConstraints.qGaussianVarianceOption.value);
            }
        }
        ArrayList<PRConstraint> constraintsList = new ArrayList<PRConstraint>();
        constraintsList.add(prConstraints);
        return constraintsList;
    }

    /**
     * Prints, for each test instance, one line per token holding the {@code --n-best}
     * output labels (and the input features when {@code --include-input} is set),
     * followed by a blank line. Instances with any wrong-length answer are logged and skipped.
     */
    private static void printDecodings(CRF crf, InstanceList testData) {
        boolean includeInput = includeInputOption.value();
        for (int i = 0; i < testData.size(); ++i) {
            Sequence input = (Sequence) testData.get(i).getData();
            Sequence[] outputs = SimpleTaggerWithConstraints.apply(crf, input, SimpleTaggerWithConstraints.nBestOption.value);
            boolean error = false;
            for (int a = 0; a < outputs.length; ++a) {
                if (outputs[a].size() != input.size()) {
                    logger.info("Failed to decode input sequence " + i + ", answer " + a);
                    error = true;
                }
            }
            if (error) {
                continue;
            }
            for (int j = 0; j < input.size(); ++j) {
                StringBuilder buf = new StringBuilder();
                for (int a = 0; a < outputs.length; ++a) {
                    buf.append(outputs[a].get(j).toString()).append(" ");
                }
                if (includeInput) {
                    FeatureVector fv = (FeatureVector) input.get(j);
                    buf.append(fv.toString(true));
                }
                System.out.println(buf.toString());
            }
            System.out.println();
        }
    }

    /**
     * Deserializes a CRF from {@code modelFile}, closing the file even on failure.
     * SECURITY: this uses Java-native deserialization; only load model files from trusted sources.
     */
    private static CRF loadModel(File modelFile) throws Exception {
        FileInputStream raw = new FileInputStream(modelFile);
        try {
            ObjectInputStream in = new ObjectInputStream(raw);
            return (CRF) in.readObject();
        } finally {
            raw.close();
        }
    }

    /** Serializes {@code crf} to {@code modelFile}, closing the file even on failure. */
    private static void saveModel(CRF crf, File modelFile) throws Exception {
        FileOutputStream raw = new FileOutputStream(modelFile);
        try {
            ObjectOutputStream out = new ObjectOutputStream(raw);
            out.writeObject(crf);
            // Push buffered object data to the file before the underlying stream is closed.
            out.flush();
        } finally {
            raw.close();
        }
    }

    /** Closes {@code reader} if non-null; a close failure is logged so it cannot mask an earlier exception. */
    private static void closeQuietly(FileReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (java.io.IOException e) {
            logger.warning("Failed to close input file: " + e);
        }
    }
}

