/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.gene.candidateretrieval.scoring;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Token2FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import de.julielab.gene.candidateretrieval.scoring.MaxEntScorerFeaturePipe;
import java.util.ArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Thin wrapper around Mallet's maximum-entropy classifier. It turns lists of
 * {@code String[]} pairs into {@link InstanceList}s, trains a {@link MaxEnt}
 * model on them, and reads back the value a classification assigns to the
 * label named {@code "TRUE"}.
 */
public class MaxEntScorerML {
    private static final Logger LOGGER = LoggerFactory.getLogger(MaxEntScorerML.class);

    /** Name of the label whose value is reported as the score. */
    private static final String TRUE_LABEL = "TRUE";

    /**
     * Builds an instance list by passing every pair through an existing pipe.
     *
     * @param pairList the pairs to convert; each element becomes the first
     *                 constructor argument of one {@link Instance}
     * @param pipe     the pipe the new instance list is created with
     * @return the instance list holding one piped instance per element of {@code pairList}
     */
    public InstanceList makeInstances(ArrayList<String[]> pairList, Pipe pipe) {
        LOGGER.debug("makeInstances() - making instances for pairs with old pipe ...");
        return buildInstanceList(pairList, pipe);
    }

    /**
     * Builds an instance list using a freshly created pipe consisting of a
     * {@link MaxEntScorerFeaturePipe} followed by a {@link Token2FeatureVector}.
     *
     * @param pairList the pairs to convert; each element becomes the first
     *                 constructor argument of one {@link Instance}
     * @return the instance list holding one piped instance per element of {@code pairList}
     */
    public InstanceList makeInstances(ArrayList<String[]> pairList) {
        LOGGER.debug("makeInstances() - making instances for pairs with new pipe ...");
        SerialPipes pipe = new SerialPipes(new Pipe[]{new MaxEntScorerFeaturePipe(), new Token2FeatureVector()});
        return buildInstanceList(pairList, pipe);
    }

    /**
     * Trains a maximum-entropy model with a default-constructed {@link MaxEntTrainer}.
     *
     * @param iList the training instances
     * @return the trained model
     */
    public Classifier train(InstanceList iList) {
        // Parameterized logging: the message is only assembled when debug is enabled.
        LOGGER.debug("train() - training the model from {} training examples ...", iList.size());
        MaxEntTrainer trainer = new MaxEntTrainer();
        return trainer.train(iList);
    }

    /**
     * Classifies a single instance and returns the value of the {@code "TRUE"} label.
     *
     * @param inst  the instance to classify
     * @param model the classifier to apply
     * @return the value the resulting labeling assigns to the {@code "TRUE"} label
     */
    public double predict(Instance inst, Classifier model) {
        return getProbabilityTrueClass(model.classify(inst));
    }

    /**
     * Classifies every instance of the list and prints, per instance, its
     * source, the value of the {@code "TRUE"} label, its name and the best
     * label of the labeling to standard output.
     *
     * @param model    the classifier to apply
     * @param pairList the instances to classify
     */
    public void eval(Classifier model, InstanceList pairList) {
        // Typed list: iterating a raw ArrayList as Classification does not type-check.
        ArrayList<Classification> classifications = model.classify(pairList);
        for (Classification c : classifications) {
            Labeling labeling = c.getLabeling();
            double predValue = getProbabilityTrueClass(c);
            System.out.println("           pair: " + c.getInstance().getSource());
            System.out.println("predicted score: " + predValue);
            System.out.println("  correct class: " + c.getInstance().getName());
            System.out.println("predicted class: " + labeling.getBestLabel() + "\n");
        }
    }

    /**
     * Creates an instance list for the given pipe and adds one instance per
     * pair; the three remaining {@link Instance} constructor arguments are
     * empty strings.
     */
    private InstanceList buildInstanceList(ArrayList<String[]> pairList, Pipe pipe) {
        InstanceList instances = new InstanceList(pipe);
        for (String[] pair : pairList) {
            instances.addThruPipe(new Instance(pair, "", "", ""));
        }
        return instances;
    }

    /**
     * Looks up the {@code "TRUE"} label in the labeling's alphabet and returns
     * the value the labeling assigns to it.
     */
    private double getProbabilityTrueClass(Classification c) {
        Labeling labeling = c.getLabeling();
        LabelAlphabet dict = labeling.getLabelAlphabet();
        // NOTE(review): if the alphabet has no "TRUE" entry, the single-argument
        // lookupLabel may add one rather than fail - confirm against Mallet.
        Label label = dict.lookupLabel(TRUE_LABEL);
        return labeling.value(label);
    }
}

