/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.crfsuite;

import com.github.jcrfsuite.util.CrfSuiteLoader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.ml.ISeqClassifier;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.vector.LabeledVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import third_party.org.chokkan.crfsuite.Attribute;
import third_party.org.chokkan.crfsuite.Item;
import third_party.org.chokkan.crfsuite.ItemSequence;
import third_party.org.chokkan.crfsuite.StringList;
import third_party.org.chokkan.crfsuite.Tagger;
import third_party.org.chokkan.crfsuite.Trainer;

/**
 * Sequence classifier that delegates training and tagging to the CRFsuite
 * native bindings ({@link Trainer} and {@link Tagger}).
 *
 * <p>Configuration is supplied through {@link #setParameter(Properties)}:
 * the {@code "model"} key is the model file path, {@code "algorithm"} and
 * {@code "graphicalModelType"} select the trainer, and every other entry is
 * forwarded verbatim to {@link Trainer#set}.
 *
 * <p>Only {@link #predict(SequenceTuple)} is synchronized; the other methods
 * perform no locking of their own.
 */
public class CRFClassifier implements ISeqClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(CRFClassifier.class);
    private static final String DEFAULT_ALGORITHM = "lbfgs";
    private static final String DEFAULT_GRAPHICAL_MODEL_TYPE = "crf1d";
    private static final String KEY_MODEL = "model";
    private static final String KEY_ALGORITHM = "algorithm";
    private static final String KEY_GRAPHICAL_MODEL_TYPE = "graphicalModelType";

    static {
        // The native CRFsuite library must be loaded before any binding class is used.
        try {
            CrfSuiteLoader.load();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Trainer settings; never contains the {@code "model"} key. */
    private Properties props = new Properties();
    /** Path of the model file that is written by training and read by the tagger. */
    private String modelPath = null;
    /** Opened tagger, or {@code null} when no model is currently loaded. */
    private Tagger tagger = null;

    /** Creates a classifier with no model path and default trainer settings. */
    public CRFClassifier() {
    }

    /**
     * Creates a classifier configured from {@code props}.
     *
     * @param props configuration, see {@link #setParameter(Properties)}
     */
    public CRFClassifier(Properties props) {
        this.setParameter(props);
    }

    /**
     * Converts the training sequences into the parallel feature/label lists
     * consumed by the trainer.
     *
     * @param trainingData sequences to convert
     * @return pair of (item sequences, label sequences), index-aligned with {@code trainingData}
     */
    private static Pair<List<ItemSequence>, List<StringList>> loadTrainingData(List<SequenceTuple> trainingData) {
        List<ItemSequence> xseqs = new ArrayList<>();
        List<StringList> yseqs = new ArrayList<>();
        for (SequenceTuple sequenceTuple : trainingData) {
            xseqs.add(getXseqForOneSeqTuple(sequenceTuple));
            StringList yseq = new StringList();
            for (String label : sequenceTuple.getLabel()) {
                yseq.add(label);
            }
            yseqs.add(yseq);
        }
        return new ImmutablePair<List<ItemSequence>, List<StringList>>(xseqs, yseqs);
    }

    /**
     * Builds the item sequence for one input sequence.
     *
     * <p>For a {@link LabeledVector} each attribute is created from the feature
     * name only (the numeric value is not passed to the attribute); for any other
     * vector the attribute is named by its index and carries the vector value.
     *
     * @param sequenceTuple sequence whose entries are converted
     * @return one {@link Item} per entry, in entry order
     */
    private static ItemSequence getXseqForOneSeqTuple(SequenceTuple sequenceTuple) {
        ItemSequence xseq = new ItemSequence();
        for (Tuple t : sequenceTuple.entries) {
            Item item = new Item();
            for (int i = 0; i < t.vector.getVector().length; ++i) {
                Attribute attr;
                if (t.vector instanceof LabeledVector) {
                    attr = new Attribute(((LabeledVector) t.vector).featsName[i]);
                } else {
                    attr = new Attribute(String.valueOf(i), t.vector.getVector()[i]);
                }
                item.add(attr);
            }
            xseq.add(item);
        }
        return xseq;
    }

    /**
     * Trains a model on {@code trainingData} and writes it to the configured
     * model path. When no model path has been configured a temporary file is
     * created and used instead.
     *
     * <p>The stored settings are not modified, so repeated calls train with
     * the same algorithm, graphical model type and parameters.
     *
     * @param trainingData labelled sequences; {@code null} or empty is a no-op (a warning is logged)
     * @return this classifier
     * @throws IllegalStateException if the temporary model file cannot be created
     */
    public ISeqClassifier train(List<SequenceTuple> trainingData) {
        if (trainingData == null || trainingData.isEmpty()) {
            LOG.warn("Training data is empty.");
            return this;
        }
        if (this.modelPath == null) {
            try {
                // The path is later opened as a regular file by the trainer, the tagger
                // and persistModel, so it has to be a file rather than a directory.
                this.modelPath = Files.createTempFile("crfsuite", ".model").toAbsolutePath().toString();
            }
            catch (IOException e) {
                throw new IllegalStateException("Create temp model file failed.", e);
            }
        }

        Pair<List<ItemSequence>, List<StringList>> crfCompatibleTrainingData = loadTrainingData(trainingData);
        List<ItemSequence> xseqs = crfCompatibleTrainingData.getLeft();
        List<StringList> yseqs = crfCompatibleTrainingData.getRight();

        Trainer trainer = new Trainer();
        String algorithm = (String) this.props.getOrDefault(KEY_ALGORITHM, DEFAULT_ALGORITHM);
        String graphicalModelType = (String) this.props.getOrDefault(KEY_GRAPHICAL_MODEL_TYPE, DEFAULT_GRAPHICAL_MODEL_TYPE);
        trainer.select(algorithm, graphicalModelType);

        // Forward everything except the two selector keys, which were consumed by select().
        for (Object key : this.props.keySet()) {
            if (KEY_ALGORITHM.equals(key) || KEY_GRAPHICAL_MODEL_TYPE.equals(key)) {
                continue;
            }
            trainer.set((String) key, (String) this.props.get(key));
        }

        for (int i = 0; i < trainingData.size(); ++i) {
            trainer.append(xseqs.get(i), yseqs.get(i), 0);
        }
        trainer.train(this.modelPath, -1);
        return this;
    }

    /**
     * Tags one sequence with the loaded model, loading the model from the
     * configured path first if none is loaded yet.
     *
     * @param sequenceTuple sequence to tag
     * @return one (label, marginal) pair per position, in order
     * @throws IllegalArgumentException if no model path is configured
     * @throws IllegalStateException if the model could not be opened
     */
    public synchronized List<Pair<String, Double>> predict(SequenceTuple sequenceTuple) {
        if (this.tagger == null) {
            this.loadModel(null);
        }
        if (this.tagger == null) {
            // Fail here with a clear message instead of inside the native tagger.
            throw new IllegalStateException("Unable to load model: " + this.modelPath);
        }
        List<Pair<String, Double>> taggedSentences = new ArrayList<>();
        this.tagger.set(getXseqForOneSeqTuple(sequenceTuple));
        StringList labels = this.tagger.viterbi();
        for (int i = 0; i < labels.size(); ++i) {
            String label = labels.get(i);
            double marginal = this.tagger.marginal(label, i);
            taggedSentences.add(new ImmutablePair<String, Double>(label, marginal));
        }
        return taggedSentences;
    }

    /**
     * Replaces the configuration. The {@code "model"} entry becomes the model
     * path ({@code null} when absent) and the remaining entries become the
     * trainer settings.
     *
     * <p>The supplied object is copied; it is neither modified nor retained.
     *
     * @param props configuration to apply
     */
    public void setParameter(Properties props) {
        Properties copy = new Properties();
        copy.putAll(props);
        this.modelPath = (String) copy.remove(KEY_MODEL);
        this.props = copy;
    }

    /**
     * Copies the current model file to {@code modelFile}.
     *
     * @param modelFile destination path
     * @throws IOException if there is no model path yet, if the destination equals
     *                     the current model path, or if the copy itself fails
     */
    public void persistModel(String modelFile) throws IOException {
        if (this.modelPath == null) {
            throw new IOException("no model path available; train or configure a model first.");
        }
        if (this.modelPath.equals(modelFile)) {
            throw new IOException("same as original model path.");
        }
        File sourceFile = new File(this.modelPath);
        File destFile = new File(modelFile);
        Files.copy(sourceFile.toPath(), destFile.toPath());
    }

    /**
     * Tags every testing sequence and counts positions whose predicted label
     * differs from the expected one. The error count and accuracy are printed
     * to standard output.
     *
     * @param testingData labelled sequences to evaluate
     * @return pair of (error count, total number of entries)
     * @throws RuntimeException if a prediction has a different length than its expected labels
     */
    public Pair<Integer, Integer> validate(List<SequenceTuple> testingData) {
        int total = testingData.stream().mapToInt(st -> st.entries.size()).sum();
        int err = 0;
        for (SequenceTuple st : testingData) {
            List<String> actual = this.predict(st).stream().map(Pair::getLeft).collect(Collectors.toList());
            List<String> expected = st.getLabel();
            if (actual.size() != expected.size()) {
                throw new RuntimeException("Actual size: " + actual.size() + "\tExpected size: " + expected.size());
            }
            for (int i = 0; i < actual.size(); ++i) {
                if (!actual.get(i).equals(expected.get(i))) {
                    ++err;
                }
            }
        }
        System.out.println("Err/Total: " + err + "/" + total);
        System.out.println("Accuracy: " + (1.0 - (double)err / (double)total) * 100.0 + "%");
        return new ImmutablePair<Integer, Integer>(err, total);
    }

    /**
     * Opens the model at the configured model path. If opening fails an error
     * is logged and no tagger is kept, so a later {@link #predict} attempts the
     * load again.
     *
     * @param modelFile not read by this implementation; the configured model path is used instead
     * @throws IllegalArgumentException if no model path is configured
     */
    public void loadModel(InputStream modelFile) {
        if (this.modelPath == null) {
            throw new IllegalArgumentException("Please set model path parameter to load model");
        }
        Tagger candidate = new Tagger();
        if (candidate.open(this.modelPath)) {
            this.tagger = candidate;
        } else {
            // Do not keep a tagger that has no model behind it.
            this.tagger = null;
            LOG.error("Unable load model: {}", this.modelPath);
        }
    }
}

