/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.jcore.ae.jnet.tagger;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.Segment;
import cc.mallet.fst.SumLatticeConstrained;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import de.julielab.jcore.ae.jnet.tagger.FeatureConfiguration;
import de.julielab.jcore.ae.jnet.tagger.FeatureGenerator;
import de.julielab.jcore.ae.jnet.tagger.FeatureSubsetModel;
import de.julielab.jcore.ae.jnet.tagger.METrainerDummyPipe;
import de.julielab.jcore.ae.jnet.tagger.Sentence;
import de.julielab.jcore.ae.jnet.tagger.Unit;
import de.julielab.jcore.ae.jnet.utils.IOEvaluation;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Properties;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Named-entity tagger backed by MALLET.
 *
 * <p>By default a linear-chain {@link CRF} is trained. When {@link #set_Max_Ent(boolean)} is
 * enabled, a {@link MaxEnt} per-token classifier is trained instead. Features are produced by
 * {@link FeatureGenerator} according to a {@link Properties} feature configuration.
 *
 * <p>Models are written and read as a gzipped, Java-serialized {@link FeatureSubsetModel} that
 * bundles the model with its feature configuration.
 *
 * <p>NOTE(review): instances hold mutable model/pipe state without any synchronization;
 * presumably not safe for concurrent use — confirm with callers.
 */
public class NETagger {
    /** The trained model: a {@link CRF} (used as a {@link Transducer}) or a {@link MaxEnt}. */
    private Object model = null;
    private Properties featureConfig = null;
    private boolean trained = false;
    static Logger LOGGER = LoggerFactory.getLogger(NETagger.class);
    /**
     * Number of training iterations. {@code 0} uses the trainer's default training call
     * ({@code trainOptimized} for the CRF, {@code train(InstanceList)} for MaxEnt).
     */
    private int number_iterations = 0;
    private boolean max_ent = false;
    /**
     * Sentence-level feature pipe. It is only assigned by {@link #train(ArrayList)}; it is NOT
     * restored by {@link #readModel(InputStream)}.
     */
    private Pipe generalPipe = null;
    private Pipe dummyPipe = null;

    /**
     * Creates a tagger whose feature configuration falls back to the classpath resource
     * {@code /defaultFeatureConf.conf}.
     *
     * @throws IllegalStateException if the default configuration resource cannot be found
     */
    public NETagger() {
        Properties defaults = new Properties();
        try (InputStream defaultFeatureConfigStream =
                this.getClass().getResourceAsStream("/defaultFeatureConf.conf")) {
            if (defaultFeatureConfigStream == null) {
                // getResourceAsStream signals a missing resource with null; fail with a clear
                // message instead of the opaque NPE Properties.load(null) would raise.
                IllegalStateException e = new IllegalStateException(
                        "default feature configuration /defaultFeatureConf.conf not found on classpath");
                LOGGER.error("", e);
                throw e;
            }
            LOGGER.debug("loading default configuration");
            defaults.load(defaultFeatureConfigStream);
        } catch (IOException e) {
            // Best effort, as before: log and continue with whatever defaults were loaded.
            LOGGER.error("could not read default feature configuration", e);
        }
        this.featureConfig = new Properties(defaults);
    }

    /**
     * Creates a tagger whose feature configuration is read from the given properties file.
     *
     * @param featureConfigFile properties file with the feature configuration
     * @throws IllegalStateException if {@code featureConfigFile} is not an existing regular file
     */
    public NETagger(File featureConfigFile) {
        this.featureConfig = new Properties();
        if (!featureConfigFile.isFile()) {
            IllegalStateException e = new IllegalStateException("specified file for feature configuration not found!");
            LOGGER.error("", e);
            throw e;
        }
        // try-with-resources: the original never closed this stream.
        try (FileInputStream in = new FileInputStream(featureConfigFile)) {
            this.featureConfig.load(in);
        } catch (IOException e) {
            LOGGER.error("could not read feature configuration from " + featureConfigFile, e);
        }
    }

    /** @return {@code true} once a model has been trained or loaded */
    public boolean isTrained() {
        return this.trained;
    }

    /**
     * Trains a new model on the given sentences and marks this tagger as trained.
     *
     * <p>A CRF is trained unless max-ent mode is enabled. In that mode the sequence features are
     * converted into per-token instances and a {@link MaxEnt} classifier is trained.
     *
     * @param sentences labelled training sentences
     */
    public void train(ArrayList<Sentence> sentences) {
        System.out.println("   * training model... on " + sentences.size() + " sentences");
        FeatureGenerator featureGenerator = new FeatureGenerator();
        InstanceList data = featureGenerator.createFeatureData(sentences, this.featureConfig);
        this.generalPipe = data.getPipe();
        LOGGER.info("  * number of features for training: " + data.getDataAlphabet().size());
        long start = System.currentTimeMillis();
        if (this.max_ent) {
            this.model = this.trainMaxEnt(data);
        } else {
            this.model = this.trainCrf(data);
        }
        long stop = System.currentTimeMillis();
        LOGGER.info("  * learning took (sec): " + (stop - start) / 1000L);
        this.trained = true;
    }

    /** Builds and trains a CRF whose states are connected as the label bigrams seen in {@code data}. */
    private CRF trainCrf(InstanceList data) {
        CRF crf = new CRF(data.getPipe(), null);
        crf.addStatesForBiLabelsConnectedAsIn(data);
        CRFTrainerByLabelLikelihood crfTrainer = new CRFTrainerByLabelLikelihood(crf);
        if (this.number_iterations == 0) {
            boolean converged = crfTrainer.trainOptimized(data);
            LOGGER.info("JNET training: model converged: " + converged);
        } else {
            crfTrainer.train(data, this.number_iterations);
            LOGGER.info("JNET training: with iterations = " + this.number_iterations);
        }
        return crf;
    }

    /** Converts the sequence data to per-token instances and trains a MaxEnt classifier on them. */
    private MaxEnt trainMaxEnt(InstanceList data) {
        this.dummyPipe = new SerialPipes(new Pipe[]{new METrainerDummyPipe(data.getDataAlphabet(), data.getTargetAlphabet())});
        InstanceList tokenData = FeatureGenerator.convertFeatsforClassifier(this.dummyPipe, data);
        LOGGER.info("train() - now training on " + data.size() + " instances");
        MaxEntTrainer maxEntTrainer = new MaxEntTrainer();
        LOGGER.info("JNET ME training ...");
        return this.number_iterations == 0
                ? maxEntTrainer.train(tokenData)
                : maxEntTrainer.train(tokenData, this.number_iterations);
    }

    /**
     * Predicts a label for every unit of {@code sentence} and stores it on the unit.
     *
     * @param sentence the sentence to tag; its units are modified in place
     * @param showSegmentConfidence CRF only: also store a per-unit segment confidence
     *     (-1.0 for units outside any predicted chunk)
     * @throws IllegalStateException if no model is available, if the CRF predicts the wrong
     *     number of labels, or if a MaxEnt model has no feature pipe (e.g. it was loaded with
     *     {@link #readModel(InputStream)} rather than trained)
     */
    public void predict(Sentence sentence, boolean showSegmentConfidence) {
        if (!this.trained || this.model == null) {
            IllegalStateException e = new IllegalStateException("No model available. Train or load trained model first.");
            LOGGER.error("", e);
            throw e;
        }
        if (this.max_ent) {
            System.out.println("  * predicting with me model...");
            Classifier classifier = (Classifier) this.model;
            Instance inst = toSentenceInstance(this.requireGeneralPipe(), sentence);
            this.classifyUnits(classifier, inst, sentence.getUnits());
            return;
        }
        Transducer crf = (Transducer) this.model;
        Instance inst = toSentenceInstance(crf.getInputPipe(), sentence);
        Sequence input = (Sequence) inst.getData();
        Sequence output = crf.transduce(input);
        if (output.size() != sentence.getUnits().size()) {
            IllegalStateException e = new IllegalStateException("Wrong number of labels predicted.");
            LOGGER.error("", e);
            throw e;
        }
        double[] conf = showSegmentConfidence ? this.getSegmentConfidence(input, output) : null;
        for (int i = 0; i < sentence.getUnits().size(); ++i) {
            Unit unit = sentence.get(i);
            unit.setLabel((String) output.get(i));
            if (showSegmentConfidence) {
                unit.setConfidence(conf[i]);
            }
        }
    }

    /**
     * Tags all sentences and returns one tab-separated line per unit ({@code rep \t label}, plus
     * {@code \t confidence} for the CRF when requested). An {@code "O\tO"} line follows each
     * sentence. Labels (and confidences) are also stored on the units.
     *
     * @param sentences sentences to tag; their units are modified in place
     * @param showSegmentConfidence CRF only: compute and emit segment confidences
     * @return the IOB-style output lines
     * @throws IllegalStateException under the same conditions as
     *     {@link #predict(Sentence, boolean)}
     */
    public ArrayList<String> predictIOB(ArrayList<Sentence> sentences, boolean showSegmentConfidence) {
        if (!this.trained || this.model == null) {
            IllegalStateException e = new IllegalStateException("no model available. Train or load trained model first.");
            LOGGER.error("", e);
            throw e;
        }
        long t1 = System.currentTimeMillis();
        ArrayList<String> iobList = new ArrayList<String>();
        if (!this.max_ent) {
            System.out.println("  * predicting with crf model...");
            Transducer crf = (Transducer) this.model;
            for (Sentence sentence : sentences) {
                Instance inst = toSentenceInstance(crf.getInputPipe(), sentence);
                Sequence input = (Sequence) inst.getData();
                Sequence output = crf.transduce(input);
                ArrayList<Unit> units = sentence.getUnits();
                if (output.size() != units.size()) {
                    IllegalStateException e = new IllegalStateException("Wrong number of labels predicted.");
                    LOGGER.error("", e);
                    throw e;
                }
                double[] conf = showSegmentConfidence ? this.getSegmentConfidence(input, output) : null;
                for (int j = 0; j < units.size(); ++j) {
                    Unit unit = sentence.get(j);
                    String label = (String) output.get(j);
                    unit.setLabel(label);
                    String iobString = units.get(j).getRep() + "\t" + label;
                    if (showSegmentConfidence) {
                        unit.setConfidence(conf[j]);
                        iobString = iobString + "\t" + conf[j];
                    }
                    iobList.add(iobString);
                }
                iobList.add("O\tO");
            }
        } else {
            System.out.println("  * predicting with me model...");
            Classifier classifier = (Classifier) this.model;
            // All sentence instances are built first (and added to one InstanceList) before any
            // classification happens, as in the original implementation.
            InstanceList instanceList = new InstanceList(this.generalPipe);
            for (Sentence sentence : sentences) {
                instanceList.add(toSentenceInstance(this.requireGeneralPipe(), sentence));
            }
            for (int i = 0; i < instanceList.size(); ++i) {
                Instance inst = (Instance) instanceList.get(i);
                ArrayList<Unit> units = sentences.get(i).getUnits();
                ArrayList<String> labels = this.classifyUnits(classifier, inst, units);
                for (int j = 0; j < labels.size(); ++j) {
                    iobList.add(units.get(j).getRep() + "\t" + labels.get(j));
                }
                iobList.add("O\tO");
            }
        }
        long t2 = System.currentTimeMillis();
        System.out.println("prediction took: " + (t2 - t1));
        return iobList;
    }

    /** Runs {@code pipe} over {@code sentence} wrapped in an instance with empty target/name/source. */
    private static Instance toSentenceInstance(Pipe pipe, Sentence sentence) {
        return pipe.instanceFrom(new Instance(sentence, "", "", ""));
    }

    /**
     * Returns the sentence-level feature pipe needed for MaxEnt prediction.
     *
     * @throws IllegalStateException if the pipe was never set. It is only assigned by
     *     {@link #train(ArrayList)} and is not part of the serialized model.
     */
    private Pipe requireGeneralPipe() {
        if (this.generalPipe == null) {
            IllegalStateException e = new IllegalStateException(
                    "No feature pipe available for MaxEnt prediction; it is only set by train(), not by readModel().");
            LOGGER.error("", e);
            throw e;
        }
        return this.generalPipe;
    }

    /**
     * Classifies each token of one sentence with the MaxEnt classifier and stores the best label
     * on the corresponding unit.
     *
     * <p>Terminates the JVM (exit status -1) if the number of token instances does not match the
     * number of units; this preserves the original behaviour.
     *
     * @return the predicted labels, in unit order
     */
    private ArrayList<String> classifyUnits(Classifier classifier, Instance sentenceInstance, ArrayList<Unit> units) {
        InstanceList tokenList = FeatureGenerator.convertFeatsforClassifier(classifier.getInstancePipe(), sentenceInstance);
        LOGGER.info("current sentence has this number of token features: " + tokenList.size());
        if (units.size() != tokenList.size()) {
            LOGGER.error("predict() - something went wrong with sequence feature conversion");
            System.exit(-1);
        }
        ArrayList<String> labels = new ArrayList<String>(tokenList.size());
        for (int j = 0; j < tokenList.size(); ++j) {
            Classification classification = classifier.classify((Instance) tokenList.get(j));
            String label = classification.getLabeling().getBestLabel().toString();
            units.get(j).setLabel(label);
            labels.add(label);
        }
        return labels;
    }

    /**
     * Computes a confidence for every position of {@code output}.
     *
     * <p>Positions inside a chunk reported by {@link IOEvaluation#getChunksIO} get that chunk's
     * constrained-lattice confidence. All other positions get {@code -1.0}.
     *
     * <p>Assumes chunk keys have the form {@code "start,stop"} with an inclusive stop, and values
     * of the form {@code "label#..."}. This is how they are parsed below; verify against
     * IOEvaluation.
     */
    private double[] getSegmentConfidence(Sequence<?> input, Sequence<?> output) {
        double[] confidenceList = new double[output.size()];
        for (int i = 0; i < confidenceList.length; ++i) {
            confidenceList[i] = -1.0;
        }
        ArrayList<String> labels = new ArrayList<String>(output.size());
        for (int i = 0; i < output.size(); ++i) {
            labels.add((String) output.get(i));
        }
        HashMap<String, String> entities = IOEvaluation.getChunksIO(labels);
        for (String key : entities.keySet()) {
            String entLabel = entities.get(key).split("#")[0];
            String[] offset = key.split(",");
            int start = Integer.parseInt(offset[0]);
            int stop = Integer.parseInt(offset[1]);
            Segment seg = new Segment(input, output, output, start, stop, entLabel, entLabel);
            double constrFBConf = this.estimateConfidenceFor(seg, null);
            for (int i = start; i <= stop; ++i) {
                confidenceList[i] = constrFBConf;
            }
        }
        return confidenceList;
    }

    /**
     * Estimates the confidence of {@code segment} under the CRF.
     *
     * <p>The confidence is {@code exp(constrained - total)}: the exponentiated difference between
     * the total weight of the lattice constrained to the segment and that of the unconstrained
     * lattice. This presumably assumes MALLET lattice weights are in log space.
     *
     * @param cachedLattice an already computed unconstrained lattice for the input, or
     *     {@code null} to compute one
     */
    private double estimateConfidenceFor(Segment segment, SumLatticeDefault cachedLattice) {
        Sequence predSequence = segment.getPredicted();
        Sequence input = segment.getInput();
        SumLatticeDefault lattice = cachedLattice == null ? new SumLatticeDefault((Transducer) this.model, input) : cachedLattice;
        SumLatticeConstrained constrainedLattice = new SumLatticeConstrained((Transducer) this.model, input, null, segment, predSequence);
        double latticeWeight = lattice.getTotalWeight();
        double constrainedLatticeWeight = constrainedLattice.getTotalWeight();
        return Math.exp(constrainedLatticeWeight - latticeWeight);
    }

    /**
     * Serializes the model and its feature configuration to {@code filename + ".gz"}.
     *
     * <p>Terminates the JVM (exit status -1) if writing fails; this preserves the original
     * behaviour.
     *
     * @param filename target path without the {@code .gz} suffix
     * @throws IllegalStateException if there is no trained model or no feature configuration
     *     (previously this called {@code System.exit(0)})
     */
    public void writeModel(String filename) {
        if (!this.trained || this.model == null || this.featureConfig == null) {
            IllegalStateException e = new IllegalStateException("train or load trained model first.");
            LOGGER.error("", e);
            throw e;
        }
        // Each stream is a separate resource so earlier ones are closed even if a later
        // constructor (or writeObject) throws.
        try (FileOutputStream fos = new FileOutputStream(new File(filename + ".gz"));
             GZIPOutputStream gout = new GZIPOutputStream(fos);
             ObjectOutputStream oos = new ObjectOutputStream(gout)) {
            oos.writeObject(new FeatureSubsetModel(this.model, this.featureConfig));
        } catch (Exception e) {
            LOGGER.error("could not write model to " + filename + ".gz", e);
            System.exit(-1);
        }
    }

    /**
     * Loads a model previously written by {@link #writeModel(String)}.
     *
     * @see #readModel(InputStream)
     */
    public void readModel(File f) throws IOException, FileNotFoundException, ClassNotFoundException {
        try (FileInputStream fis = new FileInputStream(f)) {
            this.readModel(fis);
        }
    }

    /**
     * Loads a gzipped, serialized {@link FeatureSubsetModel} and closes {@code is}.
     *
     * <p>The max-ent flag is set from the type of the loaded model, and the model's data
     * alphabet is frozen.
     *
     * <p>SECURITY NOTE: this uses Java-native deserialization. Only load model files from trusted
     * sources.
     */
    public void readModel(InputStream is) throws IOException, FileNotFoundException, ClassNotFoundException {
        FeatureSubsetModel fsm;
        try (GZIPInputStream gin = new GZIPInputStream(is);
             ObjectInputStream ois = new ObjectInputStream(gin)) {
            fsm = (FeatureSubsetModel) ois.readObject();
        }
        this.model = fsm.getModel();
        this.featureConfig = fsm.getFeatureConfig();
        this.trained = true;
        // Derive the mode from the loaded model. Previously a CRF loaded after set_Max_Ent(true)
        // kept max_ent == true and failed with a ClassCastException in predict().
        this.max_ent = this.model instanceof MaxEnt;
        if (this.max_ent) {
            ((MaxEnt) this.model).getInstancePipe().getDataAlphabet().stopGrowth();
        } else {
            ((Transducer) this.model).getInputPipe().getDataAlphabet().stopGrowth();
        }
    }

    /** @return the current model ({@link CRF} or {@link MaxEnt}), or {@code null} if none */
    public Object getModel() {
        return this.model;
    }

    public void setFeatureConfig(Properties featureConfig) {
        this.featureConfig = featureConfig;
    }

    public Properties getFeatureConfig() {
        return this.featureConfig;
    }

    /**
     * Parses one sentence in piped format into a {@link Sentence}.
     *
     * <p>Tokens are separated by tabs/spaces. Each token has the form
     * {@code word|meta_1|...|meta_k|label}, where {@code k} is the number of metas enabled in the
     * feature configuration. A meta value equal to the configured {@code gap_character} is
     * skipped.
     *
     * <p>Terminates the JVM (exit status -1) on a malformed token; this preserves the original
     * behaviour.
     */
    public Sentence PPDtoUnits(String sentence) {
        String[] tokens = sentence.trim().split("[\t ]+");
        ArrayList<Unit> units = new ArrayList<Unit>();
        FeatureConfiguration fc = new FeatureConfiguration();
        String[] trueMetas = fc.getTrueMetas(this.featureConfig);
        String gapCharacter = this.featureConfig.getProperty("gap_character");
        for (String token : tokens) {
            HashMap<String, String> metas = new HashMap<String, String>();
            String[] features = token.split("\\|+");
            String word = features[0];
            String label = features[features.length - 1];
            if (trueMetas.length + 2 != features.length) {
                System.err.println("Error in input format (PipedFormat)! Mal-formatted sentence: " + sentence + "\n token: " + token);
                System.err.println("Check your configuration file. Most probably you use more or less meta-data as specified in the configuration file.\nIf you don't use a config file, you should check whether your input files fit to the default configuration.");
                System.exit(-1);
            }
            for (String trueMeta : trueMetas) {
                int position = Integer.parseInt(this.featureConfig.getProperty(trueMeta + "_feat_position"));
                String featureName = this.featureConfig.getProperty(trueMeta + "_feat_unit");
                if (features[position].equals(gapCharacter)) {
                    continue;
                }
                metas.put(featureName, features[position]);
            }
            units.add(new Unit(0, 0, word, label, metas));
        }
        return new Sentence(units);
    }

    public int getNumber_Iterations() {
        return this.number_iterations;
    }

    /** @param number_iter training iterations; {@code 0} selects the trainer's default training */
    public void set_Number_Iterations(int number_iter) {
        this.number_iterations = number_iter;
    }

    public boolean is_Max_Ent() {
        return this.max_ent;
    }

    /** @param me_train {@code true} to train/predict with a MaxEnt classifier instead of a CRF */
    public void set_Max_Ent(boolean me_train) {
        this.max_ent = me_train;
    }
}

