/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.maxent;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import opennlp.maxent.GISModel;
import opennlp.maxent.io.GISModelReader;
import opennlp.maxent.io.PlainTextGISModelWriter;
import opennlp.model.AbstractModel;
import opennlp.model.DataReader;
import opennlp.model.PlainTextFileDataReader;
import opennlp.model.Prior;
import opennlp.model.RealValueFileEventStream;
import opennlp.model.UniformPrior;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.maxent.GISTrainer;
import org.maochen.nlp.ml.classifier.maxent.OnePassRealValueDataIndexer;
import org.maochen.nlp.ml.classifier.maxent.eventstream.EventStream;
import org.maochen.nlp.ml.classifier.maxent.eventstream.StringEventStream;
import org.maochen.nlp.ml.classifier.maxent.eventstream.TupleEventStream;
import org.maochen.nlp.ml.vector.IVector;
import org.maochen.nlp.ml.vector.LabeledVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Maximum-entropy classifier backed by an OpenNLP {@link GISModel}.
 *
 * <p>Models are trained with the project's {@link GISTrainer} over an
 * {@link OnePassRealValueDataIndexer}, using a {@link UniformPrior}. Training parameters may be
 * tuned through {@link #setParameter(Map)} before calling one of the {@code train*} methods; a
 * model can also be restored with {@link #loadModel(InputStream)}.
 *
 * <p>NOTE(review): parameters and the model are plain mutable fields with no synchronization;
 * treat an instance as not thread-safe unless confirmed otherwise.
 */
public class MaxEntClassifier implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(MaxEntClassifier.class);

    /** Whether smoothing is enabled in the trainer (parameter key {@code use_smoothing}). */
    private boolean useSmoothing = true;
    /** Number of training iterations (parameter key {@code iterations}). */
    private int iterations = 100;
    /**
     * Cutoff passed to both the data indexer and the trainer (parameter key {@code cutoff});
     * presumably the minimum feature count, as in OpenNLP — TODO confirm against GISTrainer.
     */
    private int cutoff = 0;
    /** Number of training threads (parameter key {@code nthreads}); defaults to the CPU count. */
    private int nthreads = Runtime.getRuntime().availableProcessors();
    /** Smoothing observation passed to the trainer (parameter key {@code smoothing_observation}). */
    private double smoothingObservation = 0.1;
    /** Trained or loaded model; {@code null} until training or loading succeeds. */
    private GISModel model = null;

    /**
     * Trains a model from raw string rows.
     *
     * @param trainingData rows wrapped by {@link StringEventStream}; the expected row layout is
     *     defined by that class (presumably label plus features — verify there)
     * @return this classifier, for chaining
     */
    public MaxEntClassifier trainString(List<String[]> trainingData) {
        StringEventStream es = new StringEventStream(trainingData);
        return this.train(es);
    }

    /**
     * Trains a model from an event stream using the current parameters and stores it in
     * {@link #model}.
     *
     * @param es source of training events
     * @return this classifier, for chaining
     */
    private MaxEntClassifier train(EventStream es) {
        Prior prior = new UniformPrior();
        // NOTE(review): the trailing boolean is presumably the indexer's "sort" flag, as in
        // OpenNLP's OnePassRealValueDataIndexer — confirm against the project class.
        OnePassRealValueDataIndexer indexer = new OnePassRealValueDataIndexer(es, this.cutoff, true);
        GISTrainer gisTrainer = new GISTrainer();
        gisTrainer.setSmoothing(this.useSmoothing);
        gisTrainer.setSmoothingObservation(this.smoothingObservation);
        this.model = gisTrainer.trainModel(this.iterations, indexer, prior, this.cutoff, this.nthreads);
        return this;
    }

    /**
     * Classifies a set of features given as strings.
     *
     * <p>Each entry is either {@code name} or {@code name=value}, following OpenNLP's
     * {@link RealValueFileEventStream#parseContexts(String[])} convention. A feature without an
     * explicit value is treated as having value 1.
     *
     * @param feats feature strings; not modified by this method
     * @return map from outcome label to the score returned by the model
     * @throws IllegalArgumentException if {@code feats} is {@code null}
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public Map<String, Double> predict(String[] feats) {
        if (feats == null) {
            throw new IllegalArgumentException("feats must not be null");
        }
        // parseContexts strips "=value" suffixes in place; work on a copy so the caller's array
        // is left untouched. The copy then holds the bare feature names.
        String[] featNames = feats.clone();
        float[] parsed = RealValueFileEventStream.parseContexts(featNames);
        double[] vector = new double[featNames.length];
        for (int i = 0; i < vector.length; ++i) {
            // parseContexts returns null when no entry carries an explicit value; every feature
            // then has the implicit value 1.
            vector[i] = parsed == null ? 1.0 : parsed[i];
        }
        LabeledVector labeledVector = new LabeledVector(vector);
        Tuple predict = new Tuple((IVector) labeledVector);
        labeledVector.featsName = featNames;
        return this.predict(predict);
    }

    /**
     * Trains a model from labeled tuples.
     *
     * @param trainingData training tuples, wrapped by {@link TupleEventStream}
     * @return this classifier
     */
    public IClassifier train(List<Tuple> trainingData) {
        TupleEventStream es = new TupleEventStream(trainingData);
        return this.train(es);
    }

    /**
     * Classifies a tuple whose vector is a {@link LabeledVector}.
     *
     * <p>The vector's {@code featsName} are used as the feature names. Its values are narrowed to
     * {@code float} before being handed to the model.
     *
     * @param predict tuple to classify
     * @return map from outcome label to the score returned by {@link GISModel#eval}
     * @throws IllegalArgumentException if the tuple's vector is not a {@link LabeledVector}
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public Map<String, Double> predict(Tuple predict) {
        if (!(predict.vector instanceof LabeledVector)) {
            throw new IllegalArgumentException("Please use LabeledVector");
        }
        if (this.model == null) {
            throw new IllegalStateException("MaxEnt model is not trained or loaded.");
        }
        String[] featsName = ((LabeledVector) predict.vector).featsName;
        double[] values = predict.vector.getVector();
        float[] featureVector = new float[values.length];
        for (int i = 0; i < values.length; ++i) {
            featureVector[i] = (float) values[i];
        }
        double[] prob = this.model.eval(featsName, featureVector, new double[this.model.getNumOutcomes()]);
        Map<String, Double> resultMap = new HashMap<String, Double>();
        for (int i = 0; i < prob.length; ++i) {
            resultMap.put(this.model.getOutcome(i), prob[i]);
        }
        return resultMap;
    }

    /**
     * Overrides training parameters.
     *
     * <p>Recognized keys are {@code use_smoothing}, {@code iterations}, {@code cutoff},
     * {@code nthreads} and {@code smoothing_observation}. Absent keys keep their current values,
     * and a {@code null} map is ignored.
     *
     * @param paraMap parameter map; may be {@code null}
     * @throws NumberFormatException if a numeric parameter cannot be parsed
     */
    public void setParameter(Map<String, String> paraMap) {
        if (paraMap == null) {
            return;
        }
        if (paraMap.containsKey("use_smoothing")) {
            this.useSmoothing = Boolean.parseBoolean(paraMap.get("use_smoothing"));
        }
        if (paraMap.containsKey("iterations")) {
            this.iterations = Integer.parseInt(paraMap.get("iterations"));
        }
        if (paraMap.containsKey("cutoff")) {
            this.cutoff = Integer.parseInt(paraMap.get("cutoff"));
        }
        if (paraMap.containsKey("nthreads")) {
            this.nthreads = Integer.parseInt(paraMap.get("nthreads"));
        }
        if (paraMap.containsKey("smoothing_observation")) {
            this.smoothingObservation = Double.parseDouble(paraMap.get("smoothing_observation"));
        }
    }

    /**
     * Writes the current model to a file in OpenNLP's plain-text GIS format.
     *
     * @param modelPath destination file path
     * @throws IOException if writing fails
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public void persistModel(String modelPath) throws IOException {
        if (this.model == null) {
            throw new IllegalStateException("MaxEnt model is not trained or loaded.");
        }
        File outputFile = new File(modelPath);
        PlainTextGISModelWriter writer = new PlainTextGISModelWriter((AbstractModel) this.model, outputFile);
        writer.persist();
    }

    /**
     * Loads a model in OpenNLP's plain-text GIS format from a stream.
     *
     * <p>On an {@link IOException} the error is logged and the previously held model, if any, is
     * kept. NOTE(review): this method does not explicitly close the stream; presumably the caller
     * owns it. The loaded model is cast to {@link GISModel}, so a non-GIS model would raise a
     * {@link ClassCastException}.
     *
     * @param modelPath stream holding the serialized model
     */
    public void loadModel(InputStream modelPath) {
        LOG.info("Loading MaxEnt model.");
        GISModelReader modelReader = new GISModelReader(new PlainTextFileDataReader(modelPath));
        try {
            AbstractModel loaded = modelReader.getModel();
            this.model = (GISModel) loaded;
        } catch (IOException e) {
            LOG.error("model load err.", e);
        }
    }
}

