/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.classifier.perceptron;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.perceptron.PerceptronModel;
import org.maochen.nlp.util.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Multi-class perceptron classifier.
 *
 * <p>The model keeps one weight vector and one bias per label. A feature vector {@code x} is
 * scored against every label as {@code w_label . x + bias_label}, and the highest score wins.
 * Training runs online perceptron updates over shuffled data. It stops once every training
 * example is classified correctly or the iteration cap is reached.
 */
public class PerceptronClassifier implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(PerceptronClassifier.class);

    /** Default cap on training passes, used unless the {@code iter} property overrides it. */
    private static final int DEFAULT_MAX_ITERATION = 200;

    protected PerceptronModel model = new PerceptronModel();

    /**
     * Last properties passed to {@link #setParameter}. They are re-applied whenever
     * {@link #train} builds a new model.
     */
    private Properties properties = null;

    /**
     * Cap on training passes. This is an instance field (not static) so that the {@code iter}
     * property of one classifier cannot change the behaviour of other instances.
     */
    private int maxIteration = DEFAULT_MAX_ITERATION;

    /**
     * Computes the raw perceptron score {@code w_i . x + bias_i} for every label index.
     *
     * @param x feature vector
     * @return map from label index to raw (un-squashed) score; empty if the model has no labels
     */
    private Map<Integer, Double> predict(double[] x) {
        Map<Integer, Double> scores = new HashMap<>();
        for (int i = 0; i < model.weights.length; ++i) {
            scores.put(i, VectorUtils.dotProduct(x, model.weights[i]) + model.bias[i]);
        }
        return scores;
    }

    /**
     * Finds the label index with the highest raw score.
     *
     * @param x feature vector
     * @return (label index, raw score) of the best label, or {@code null} if the model has no
     *     labels
     */
    private Pair<Integer, Double> predictMax(double[] x) {
        return predict(x).entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(e -> new ImmutablePair<>(e.getKey(), e.getValue()))
                .orElse(null);
    }

    /**
     * Returns a new weight vector {@code weight + learningRate * correctionD * x}.
     *
     * @param x           feature vector
     * @param weight      current weight vector (not modified)
     * @param correctionD direction of the correction: +1 pulls towards {@code x}, -1 pushes away
     * @return the updated weight vector
     */
    private double[] reweight(double[] x, double[] weight, double correctionD) {
        return IntStream.range(0, x.length)
                .mapToDouble(i -> weight[i] + model.learningRate * correctionD * x[i])
                .toArray();
    }

    /**
     * Applies one perceptron correction to a label's weight vector and bias.
     *
     * @param labelIndex  label whose parameters are updated
     * @param x           feature vector of the training example
     * @param correctionD +1 to pull the label towards {@code x}, -1 to push it away
     */
    private void applyCorrection(int labelIndex, double[] x, double correctionD) {
        model.weights[labelIndex] = reweight(x, model.weights[labelIndex], correctionD);
        // The bias acts as a weight on a constant input of 1, so it accumulates the same scaled
        // correction. Assigning the raw correction would discard all earlier bias updates.
        model.bias[labelIndex] += model.learningRate * correctionD;
    }

    /**
     * Performs one online perceptron step on a single example.
     *
     * <p>If the current best-scoring label differs from {@code labelIndex}, the true label's
     * weights and bias are moved towards {@code x}. The wrongly predicted label's weights and bias
     * are moved away from it.
     *
     * @param x          feature vector
     * @param labelIndex index of the example's true label
     * @throws IllegalStateException if the model has no label weight vectors
     */
    public void onlineTrain(double[] x, int labelIndex) {
        Pair<Integer, Double> best = predictMax(x);
        if (best == null) {
            throw new IllegalStateException("Model has no labels; train or load a model first.");
        }
        int predictedIndex = best.getLeft();
        if (predictedIndex != labelIndex) {
            applyCorrection(labelIndex, x, 1.0);
            applyCorrection(predictedIndex, x, -1.0);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("New bias: {}", Arrays.toString(model.bias));
            LOG.debug("New weight: {}", Arrays.stream(model.weights)
                    .map(Arrays::toString)
                    .collect(Collectors.joining(", ")));
        }
    }

    /**
     * Trains a fresh model on {@code trainingData}.
     *
     * <p>The method first builds a new {@link PerceptronModel} from the data. It then re-applies
     * any properties previously given to {@link #setParameter}. After that it repeats shuffled
     * online passes until no training example is misclassified or the iteration cap is reached
     * ({@code iter} property, default 200). The caller's list is not reordered; a private copy is
     * shuffled.
     *
     * @param trainingData labelled examples
     * @return this classifier
     */
    public IClassifier train(List<Tuple> trainingData) {
        model = new PerceptronModel(trainingData);
        // The new model starts from its own defaults, so stored hyper-parameters must be re-applied.
        setParameter(properties);

        List<Tuple> shuffled = new ArrayList<>(trainingData);
        int iter = 0;
        int errCount;
        do {
            iter++;
            LOG.debug("Iteration {}", iter);
            Collections.shuffle(shuffled);
            for (Tuple entry : shuffled) {
                onlineTrain(entry.vector.getVector(), model.labelIndexer.getIndex(entry.label));
            }
            errCount = countErrors(shuffled);
        } while (errCount != 0 && iter < maxIteration);
        LOG.debug("Err size: {}", errCount);
        return this;
    }

    /**
     * Counts examples whose best-scoring label differs from their true label.
     *
     * @param data labelled examples
     * @return number of misclassified examples
     */
    private int countErrors(List<Tuple> data) {
        return (int) data.stream()
                .filter(t -> predictMax(t.vector.getVector()).getLeft().intValue()
                        != model.labelIndexer.getIndex(t.label))
                .count();
    }

    /**
     * Scores {@code predict} against every label.
     *
     * @param predict example to classify
     * @return map from label to the sigmoid of that label's raw perceptron score
     */
    public Map<String, Double> predict(Tuple predict) {
        Map<Integer, Double> indexResult = predict(predict.vector.getVector());
        return indexResult.entrySet().stream().collect(Collectors.toMap(
                e -> model.labelIndexer.getLabel(e.getKey()),
                e -> VectorUtils.sigmoid.apply(e.getValue())));
    }

    /**
     * Stores and applies hyper-parameters.
     *
     * <p>Recognised keys:
     * <ul>
     *   <li>{@code learning_rate}: a double, written to the model.</li>
     *   <li>{@code iter}: an int, the maximum number of training passes.</li>
     *   <li>{@code threshold}: a double, written to the model; this class does not read it.</li>
     * </ul>
     * The properties are remembered so that {@link #train} can re-apply them to the model it
     * creates.
     *
     * @param props hyper-parameters; {@code null} clears the stored properties and changes
     *     nothing else
     * @throws NumberFormatException if a recognised value cannot be parsed
     */
    public void setParameter(Properties props) {
        this.properties = props;
        if (props == null) {
            return;
        }
        if (props.containsKey("learning_rate")) {
            model.learningRate = Double.parseDouble(props.getProperty("learning_rate"));
        }
        if (props.containsKey("iter")) {
            maxIteration = Integer.parseInt(props.getProperty("iter"));
        }
        if (props.containsKey("threshold")) {
            model.threshold = Double.parseDouble(props.getProperty("threshold"));
        }
    }

    /**
     * Writes the current model to {@code modelFile} by delegating to the model.
     *
     * @param modelFile destination path
     * @throws IOException if the model cannot be written
     */
    public void persistModel(String modelFile) throws IOException {
        model.persist(modelFile);
    }

    /**
     * Loads model state from {@code inputStream} into the current model by delegating to it.
     *
     * @param inputStream source of a previously persisted model
     */
    public void loadModel(InputStream inputStream) {
        model.load(inputStream);
    }
}

