/*
 * Decompiled with CFR 0.152.
 */
package se.lth.cs.srl.pipeline;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import se.lth.cs.srl.corpus.ArgMap;
import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.corpus.Word;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.ml.Model;
import se.lth.cs.srl.ml.liblinear.Label;
import se.lth.cs.srl.pipeline.ArgumentStep;

public class ArgumentClassifier
extends ArgumentStep {
    private static final String FILEPREFIX = "ac_";
    private List<String> argLabels;

    public ArgumentClassifier(FeatureSet fs, List<String> argLabels) {
        super(fs);
        this.argLabels = argLabels;
    }

    @Override
    public void extractInstances(Sentence s) {
        for (Predicate pred : s.getPredicates()) {
            for (Word arg : pred.getArgMap().keySet()) {
                super.addInstance(pred, arg);
            }
        }
    }

    @Override
    public void parse(Sentence s) {
        for (Predicate pred : s.getPredicates()) {
            Map<Word, String> argMap = pred.getArgMap();
            for (Word arg : argMap.keySet()) {
                Integer label = super.classifyInstance(pred, arg);
                argMap.put(arg, this.argLabels.get(label));
            }
        }
    }

    @Override
    protected Integer getLabel(Predicate pred, Word arg) {
        return this.argLabels.indexOf(pred.getArgMap().get(arg));
    }

    @Override
    public void prepareLearning() {
        super.prepareLearning(FILEPREFIX);
    }

    @Override
    protected String getModelFileName() {
        return "ac_.models";
    }

    List<ArgMap> beamSearch(Predicate pred, List<ArgMap> candidates, int beamSize) {
        ArrayList<ArgMap> ret = new ArrayList<ArgMap>();
        String POSPrefix = super.getPOSPrefix(pred.getPOS());
        if (POSPrefix == null) {
            POSPrefix = this.featureSet.POSPrefixes[0];
        }
        Model model = (Model)this.models.get(POSPrefix);
        for (ArgMap argMap : candidates) {
            ArrayList<ArgMap> branches = new ArrayList<ArgMap>();
            branches.add(argMap);
            TreeSet<ArgMap> newBranches = new TreeSet<ArgMap>(ArgMap.REVERSE_PROB_COMPARATOR);
            for (Word arg : argMap.keySet()) {
                Collection<Integer> indices = super.collectIndices(pred, arg, POSPrefix, new TreeSet<Integer>());
                List<Label> probs = model.classifyProb(indices);
                for (ArgMap branch : branches) {
                    for (int i = 0; i < beamSize; ++i) {
                        Label label = probs.get(i);
                        ArgMap newBranch = new ArgMap(branch);
                        newBranch.put(arg, this.argLabels.get(label.getLabel()), label.getProb());
                        newBranches.add(newBranch);
                    }
                }
                branches.clear();
                Iterator it = newBranches.iterator();
                for (int i = 0; i < beamSize && it.hasNext(); ++i) {
                    ArgMap cur = (ArgMap)it.next();
                    branches.add(cur);
                }
                newBranches.clear();
            }
            int size = branches.size();
            for (int i = 0; i < beamSize && i < size; ++i) {
                ArgMap cur = (ArgMap)branches.get(i);
                cur.setLblProb(cur.getProb());
                cur.resetProb();
                ret.add(cur);
            }
        }
        return ret;
    }
}

