/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.libsvm;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.LabelIndexer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link IClassifier} backed by libsvm.
 *
 * <p>String labels are mapped to numeric class ids through a {@link LabelIndexer}. Each feature
 * vector is handed to libsvm as one {@code svm_node} per element, and the element's 0-based
 * position is used as the node index. Training and prediction share that mapping (see
 * {@link #toSvmNodes(double[])}), so models persisted by this class remain compatible.
 */
public class LibSVMClassifier implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(LibSVMClassifier.class);

    /** Name suffix of the archive entry that holds the libsvm model. */
    private static final String MODEL_ENTRY_SUFFIX = ".model";

    /** Name suffix of the archive entry that holds the serialized label indexer. */
    private static final String LABEL_INDEXER_ENTRY_SUFFIX = ".lbindexer";

    private svm_model model = null;

    /**
     * Training parameters. When {@code null}, {@link #train(List)} falls back to
     * {@link #getDefaultPara()}.
     */
    public svm_parameter para = null;

    private LabelIndexer labelIndexer = null;

    public LibSVMClassifier() {
        this.writeToLog();
    }

    /**
     * Routes libsvm's console output to this class's logger.
     *
     * <p>{@code svm_set_print_string_function} is a static libsvm setting. Calling it therefore
     * affects every libsvm call in the JVM, not only this instance.
     */
    private void writeToLog() {
        svm.svm_set_print_string_function(msg -> {
            if (msg == null) {
                return;
            }
            // Drop libsvm's bare progress dots and any surrounding line breaks, so that each log
            // record is a single clean message.
            String trimmed = msg.trim();
            if (!trimmed.isEmpty() && !".".equals(trimmed)) {
                LOG.info(trimmed);
            }
        });
    }

    /**
     * Builds the default training parameters: a C-SVC with a linear kernel and probability
     * estimates enabled.
     *
     * @return a new, independently mutable parameter object
     */
    public svm_parameter getDefaultPara() {
        svm_parameter defaults = new svm_parameter();
        defaults.svm_type = svm_parameter.C_SVC;
        defaults.kernel_type = svm_parameter.LINEAR;
        // predict() asks libsvm for per-class probabilities, which requires a probability model.
        defaults.probability = 1;
        defaults.C = 100.0;
        // gamma, nu and p are ignored by a linear C-SVC; they are kept so that switching
        // svm_type/kernel_type starts from reasonable values.
        defaults.gamma = 0.5;
        defaults.nu = 0.5;
        defaults.p = 0.1;
        defaults.cache_size = 20000.0; // kernel cache size, in MB
        defaults.eps = 0.001;          // stopping tolerance
        return defaults;
    }

    /**
     * Trains a libsvm model on the given tuples.
     *
     * @param trainingData labelled tuples; rows may differ in length (libsvm rows are sparse)
     * @return this classifier, now holding the trained model
     * @throws IllegalArgumentException if {@code trainingData} is null or empty, or if libsvm
     *     rejects the parameter/problem combination
     */
    public IClassifier train(List<Tuple> trainingData) {
        if (trainingData == null || trainingData.isEmpty()) {
            throw new IllegalArgumentException("Training data is empty.");
        }
        if (this.para == null) {
            LOG.warn("Parameter is null. Use the default parameter.");
            this.para = this.getDefaultPara();
        }
        this.labelIndexer = new LabelIndexer(trainingData);

        svm_problem prob = new svm_problem();
        prob.l = trainingData.size();
        prob.y = new double[prob.l];
        // Rows are allocated per tuple, so only the outer array is created here.
        prob.x = new svm_node[prob.l][];
        int row = 0;
        // Iterate rather than use get(i), which would be O(n^2) for linked lists.
        for (Tuple tuple : trainingData) {
            prob.x[row] = toSvmNodes(tuple.vector.getVector());
            prob.y[row] = this.labelIndexer.getIndex(tuple.label);
            row++;
        }

        // libsvm does not validate its inputs inside svm_train; check them explicitly first.
        String paraError = svm.svm_check_parameter(prob, this.para);
        if (paraError != null) {
            throw new IllegalArgumentException("Invalid libsvm parameter: " + paraError);
        }
        this.model = svm.svm_train(prob, this.para);
        return this;
    }

    /**
     * Predicts per-label probabilities for one tuple.
     *
     * @param predict tuple whose vector is classified
     * @return a map from label to its estimated probability, with one entry per model class
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public Map<String, Double> predict(Tuple predict) {
        if (this.model == null || this.labelIndexer == null) {
            throw new IllegalStateException("Model is not available. Train or load a model first.");
        }
        svm_node[] svmfeats = toSvmNodes(predict.vector.getVector());

        // Size the buffers from the model itself, because libsvm writes nr_class entries. The
        // label indexer is restored separately and must not dictate the array sizes.
        int nrClass = svm.svm_get_nr_class(this.model);
        int[] labels = new int[nrClass];
        svm.svm_get_labels(this.model, labels);
        double[] probs = new double[nrClass];
        svm.svm_predict_probability(this.model, svmfeats, probs);

        Map<String, Double> result = new HashMap<>();
        for (int i = 0; i < nrClass; ++i) {
            result.put(this.labelIndexer.getLabel(labels[i]), probs[i]);
        }
        return result;
    }

    public void setParameter(Properties props) {
        throw new NotImplementedException("Use direct set para for now.");
    }

    /**
     * Writes the model and label indexer to a zip archive at {@code modelFile}.
     *
     * <p>The archive contains two entries: {@code <name>.model}, the libsvm text model, and
     * {@code <name>.lbindexer}, the serialized label indexer. Here {@code <name>} is the file name
     * of {@code modelFile}. The label indexer is encoded with the platform default charset, which
     * keeps the format compatible with archives written by earlier versions of this class.
     *
     * @param modelFile path of the archive to create or overwrite
     * @throws IOException if writing either the scratch file or the archive fails
     * @throws RuntimeException if there is no label indexer or model to persist
     */
    public void persistModel(String modelFile) throws IOException {
        if (this.labelIndexer == null) {
            throw new RuntimeException("LabelIndexer is null!");
        }
        if (this.model == null) {
            throw new IllegalStateException("SVM model is null! Train or load a model first.");
        }
        File target = new File(modelFile);
        String svmModelFilename = new File(modelFile + MODEL_ENTRY_SUFFIX).getName();
        String labelIndexerFileName = new File(modelFile + LABEL_INDEXER_ENTRY_SUFFIX).getName();

        // libsvm can only save to a path, so stage the model in a uniquely named scratch file next
        // to the target. This avoids clobbering an existing "<modelFile>.model".
        File svmScratch = File.createTempFile("libsvm-", MODEL_ENTRY_SUFFIX,
                target.getAbsoluteFile().getParentFile());
        try {
            svm.svm_save_model(svmScratch.getAbsolutePath(), this.model);
            try (ZipOutputStream zipos = new ZipOutputStream(new FileOutputStream(target))) {
                zipos.putNextEntry(new ZipEntry(svmModelFilename));
                try (InputStream svmIs = new FileInputStream(svmScratch)) {
                    IOUtils.copy(svmIs, zipos);
                }
                zipos.closeEntry();

                zipos.putNextEntry(new ZipEntry(labelIndexerFileName));
                IOUtils.write(this.labelIndexer.serializeToString(), zipos, Charset.defaultCharset());
                zipos.closeEntry();
            }
        } finally {
            // Always remove the scratch file. Its input stream is already closed, so deletion
            // also succeeds on platforms that lock open files.
            FileUtils.deleteQuietly(svmScratch);
        }
    }

    /**
     * Restores the model and label indexer from a zip archive written by
     * {@link #persistModel(String)}.
     *
     * <p>Entries ending in {@code .model} are parsed by libsvm. Entries ending in
     * {@code .lbindexer} are read as the serialized label indexer. Each field is replaced only
     * when its entry is present. I/O failures are logged rather than thrown. The caller keeps
     * ownership of {@code modelIs}, which is not closed here.
     *
     * @param modelIs stream positioned at the start of the zip archive
     */
    public void loadModel(InputStream modelIs) {
        boolean modelFound = false;
        boolean labelIndexerFound = false;
        // Shield the caller's stream so that closing our zip wrapper leaves it open.
        try (ZipInputStream zis = new ZipInputStream(new NonClosingInputStream(modelIs))) {
            ZipEntry entry;
            while ((entry = zis.getNextEntry()) != null) {
                String name = entry.getName();
                if (name.endsWith(MODEL_ENTRY_SUFFIX)) {
                    modelFound = true;
                    // libsvm's svm_load_model(BufferedReader) closes the reader it is given.
                    // Shield the zip stream so the remaining entries can still be read in this pass.
                    BufferedReader br = new BufferedReader(new InputStreamReader(
                            new NonClosingInputStream(zis), Charset.defaultCharset()));
                    this.model = svm.svm_load_model(br);
                    if (this.model == null) {
                        LOG.error("libsvm failed to parse model entry {}", name);
                    }
                } else if (name.endsWith(LABEL_INDEXER_ENTRY_SUFFIX)) {
                    labelIndexerFound = true;
                    String lbIndexer = IOUtils.toString(zis, Charset.defaultCharset());
                    this.labelIndexer = new LabelIndexer(new ArrayList<Tuple>());
                    this.labelIndexer.readFromSerializedString(lbIndexer);
                }
            }
        } catch (IOException e) {
            LOG.error("Load model err.", e);
            return;
        }
        if (!modelFound) {
            LOG.warn("No '{}' entry found in model archive.", MODEL_ENTRY_SUFFIX);
        }
        if (!labelIndexerFound) {
            LOG.warn("No '{}' entry found in model archive.", LABEL_INDEXER_ENTRY_SUFFIX);
        }
    }

    /**
     * Converts a feature vector into libsvm nodes, using each element's 0-based position as the
     * node index. Shared by training and prediction so both encode features identically.
     */
    private static svm_node[] toSvmNodes(double[] feats) {
        svm_node[] nodes = new svm_node[feats.length];
        for (int i = 0; i < feats.length; ++i) {
            svm_node node = new svm_node();
            node.index = i;
            node.value = feats[i];
            nodes[i] = node;
        }
        return nodes;
    }

    /** Input stream wrapper whose {@code close()} leaves the wrapped stream open. */
    private static final class NonClosingInputStream extends FilterInputStream {
        NonClosingInputStream(InputStream in) {
            super(in);
        }

        @Override
        public void close() {
            // Intentionally a no-op: the wrapped stream is owned by someone else.
        }
    }
}

