/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.perceptron;

import com.google.common.collect.Lists;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.LabelIndexer;
import org.maochen.nlp.ml.util.ModelSerializeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Parameter container for a perceptron classifier: one weight row and one bias entry per
 * label, plus the {@link LabelIndexer} that maps labels to row indices.
 *
 * <p>The model can be sized from training data ({@link #init}), written to a text file
 * ({@link #persist}) and read back ({@link #load}). Fields are package-private because
 * other classes in this package read and update them directly.
 */
public class PerceptronModel {
    private static final Logger LOG = LoggerFactory.getLogger(PerceptronModel.class);

    /** Tokenizer for model-file lines; compiled once instead of on every {@code split} call. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /** Marker line written by {@link #persist} immediately before the label-index entries. */
    private static final String LABEL_INDEXER_MARKER = "li";

    // learningRate and threshold are only stored, copied and (de)serialized in this class;
    // presumably they are consumed by the trainer/classifier -- confirm against callers.
    double learningRate = 0.1;
    double threshold = 0.5;
    /** One bias value per label; {@code null} until {@link #init} or {@link #load} runs. */
    double[] bias = null;
    /** {@code weights[labelIndex][featureIndex]}; {@code null} until initialised or loaded. */
    double[][] weights = null;
    LabelIndexer labelIndexer;

    /** Creates an empty model with default hyper-parameters and no weights. */
    public PerceptronModel() {
    }

    /**
     * Copy constructor. The bias vector and weight matrix are deep-copied; the
     * {@link LabelIndexer} instance is shared with {@code model}.
     *
     * <p>A source model whose {@code bias}, {@code weights} or individual weight rows are
     * still {@code null} is copied as-is rather than failing with a
     * {@link NullPointerException}.
     *
     * @param model the model to copy; must not be {@code null}
     */
    public PerceptronModel(PerceptronModel model) {
        this.learningRate = model.learningRate;
        this.threshold = model.threshold;
        this.labelIndexer = model.labelIndexer;
        this.bias = model.bias == null ? null : model.bias.clone();
        if (model.weights != null) {
            this.weights = new double[model.weights.length][];
            for (int i = 0; i < model.weights.length; ++i) {
                double[] row = model.weights[i];
                this.weights[i] = row == null ? null : row.clone();
            }
        }
    }

    /**
     * Builds the label indexer from {@code trainingData} and allocates the weight matrix
     * ({@code labelSize x featureLength}) and the bias vector ({@code labelSize}). The
     * feature length is taken from the first tuple's vector.
     *
     * @param trainingData     non-empty training set
     * @param initWeightRandom when {@code true} every weight is set to {@link Math#random()};
     *                         otherwise weights stay at zero. Bias is always zero.
     * @throws IllegalArgumentException if {@code trainingData} is {@code null} or empty
     */
    public void init(List<Tuple> trainingData, boolean initWeightRandom) {
        // Validate before touching any field so a failed init leaves the model unchanged.
        if (trainingData == null || trainingData.isEmpty()) {
            throw new IllegalArgumentException("trainingData must contain at least one tuple.");
        }
        this.labelIndexer = new LabelIndexer(trainingData);
        int featureLength = trainingData.get(0).vector.getVector().length;
        int labelSize = this.labelIndexer.getLabelSize();
        this.weights = new double[labelSize][featureLength];
        this.bias = new double[labelSize];
        if (initWeightRandom) {
            for (double[] row : this.weights) {
                for (int j = 0; j < row.length; ++j) {
                    row[j] = Math.random();
                }
            }
        }
    }

    /**
     * Writes the model to {@code filename} as UTF-8 text: learning rate, threshold, the
     * serialized bias, the serialized weights, a {@code "li"} marker line and the serialized
     * label-index entries. The exact layout of the serialized sections is defined by
     * {@link ModelSerializeUtils}.
     *
     * <p>I/O failures are logged, not rethrown.
     *
     * @param filename path of the file to create or overwrite
     * @throws IllegalStateException if the model has no label indexer yet; checked before
     *                               the file is opened so an existing file is not clobbered
     *                               with a partial model
     */
    public void persist(String filename) {
        if (this.labelIndexer == null) {
            throw new IllegalStateException(
                    "Model has no label indexer; call init() or load() before persist().");
        }
        // Explicit UTF-8 so the file does not depend on the writing JVM's default charset.
        try (BufferedWriter output = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(filename), StandardCharsets.UTF_8))) {
            output.write(String.valueOf(this.learningRate));
            output.write(System.lineSeparator());
            output.write(String.valueOf(this.threshold));
            output.write(System.lineSeparator());
            output.write(ModelSerializeUtils.oneDimensionArraySerialize(this.bias));
            output.write(ModelSerializeUtils.twoDimensionalArraySerialize(this.weights));
            output.write(LABEL_INDEXER_MARKER + System.lineSeparator());
            output.write(ModelSerializeUtils.mapSerialize(this.labelIndexer.labelIndexer.entrySet()));
        } catch (IOException e) {
            LOG.error("Persist model err.", e);
        }
    }

    /**
     * Reads a model from {@code is} (decoded as UTF-8) and closes the stream.
     *
     * <p>Layout expected by this reader, by physical line number: line 1 learning rate,
     * line 2 threshold, line 4 bias values, line 5 a header whose first token is the number
     * of weight rows, followed by that many weight rows. After a {@code "li"} line
     * (case-insensitive) every non-blank line is read as {@code label index}. Line 3 is not
     * interpreted here -- presumably a header emitted by the bias serializer; confirm
     * against {@link ModelSerializeUtils}. Blank lines are counted but otherwise skipped.
     *
     * <p>I/O failures are logged, not rethrown.
     *
     * @param is stream positioned at the start of a persisted model
     * @throws IllegalArgumentException if the weight section is truncated, a label line has
     *                                  fewer than two tokens, or a number cannot be parsed
     *                                  (as {@link NumberFormatException})
     */
    public void load(InputStream is) {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            String line;
            int lineCount = 0;
            boolean isLabelIndexer = false;
            while ((line = br.readLine()) != null) {
                ++lineCount;
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                if (lineCount == 1) {
                    this.learningRate = Double.parseDouble(line);
                } else if (lineCount == 2) {
                    this.threshold = Double.parseDouble(line);
                } else if (lineCount == 4) {
                    this.bias = parseDoubles(line);
                } else if (lineCount == 5) {
                    int rows = Integer.parseInt(WHITESPACE.split(line)[0]);
                    this.weights = readWeightRows(br, rows);
                    // Keep the counter in step with the rows consumed by the helper.
                    lineCount += rows;
                } else if (LABEL_INDEXER_MARKER.equalsIgnoreCase(line)) {
                    isLabelIndexer = true;
                    this.labelIndexer = new LabelIndexer(Lists.newArrayList());
                } else if (isLabelIndexer) {
                    String[] tokens = WHITESPACE.split(line);
                    if (tokens.length < 2) {
                        throw new IllegalArgumentException(
                                "Malformed label entry at line " + lineCount + ": '" + line + "'");
                    }
                    this.labelIndexer.labelIndexer.put(tokens[0], Integer.parseInt(tokens[1]));
                }
            }
        } catch (IOException e) {
            LOG.error("Load model err.", e);
        }
    }

    /** Parses a trimmed, whitespace-separated line of decimal numbers. */
    private static double[] parseDoubles(String line) {
        return Arrays.stream(WHITESPACE.split(line)).mapToDouble(Double::parseDouble).toArray();
    }

    /** Reads exactly {@code rows} weight rows from {@code br}, failing clearly on early EOF. */
    private static double[][] readWeightRows(BufferedReader br, int rows) throws IOException {
        double[][] matrix = new double[rows][];
        for (int r = 0; r < rows; ++r) {
            String rowLine = br.readLine();
            if (rowLine == null) {
                throw new IllegalArgumentException("Model file is truncated: expected " + rows
                        + " weight rows but found " + r + ".");
            }
            matrix[r] = parseDoubles(rowLine.trim());
        }
        return matrix;
    }
}

