/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.classifier.perceptron;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.classifier.IClassifier;
import org.maochen.nlp.classifier.perceptron.PerceptronModel;
import org.maochen.nlp.datastructure.Tuple;
import org.maochen.nlp.utils.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Multi-class perceptron classifier.
 *
 * <p>The {@link PerceptronModel} holds one weight vector and one bias per label index. Training
 * runs epochs of online perceptron updates over a shuffled copy of the training data. It stops
 * once every training example is classified correctly, or after {@link #MAX_ITERATION} epochs.
 */
public class PerceptronClassifier
implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(PerceptronClassifier.class);
    protected PerceptronModel model = new PerceptronModel();
    /** Upper bound on the number of training epochs. */
    private static final int MAX_ITERATION = 200;

    /**
     * Computes the raw linear score {@code w_i . x + b_i} for every label index.
     *
     * @param x feature vector
     * @return map from label index to its raw (un-normalized) score
     */
    private Map<Integer, Double> predict(double[] x) {
        Map<Integer, Double> scores = new HashMap<>();
        for (int i = 0; i < this.model.weights.length; ++i) {
            scores.put(i, VectorUtils.dotProduct(x, this.model.weights[i]) + this.model.bias[i]);
        }
        return scores;
    }

    /**
     * Returns the label index with the highest raw score together with that score.
     *
     * @param x feature vector
     * @return (label index, raw score), or {@code null} when the model has no labels
     */
    private Pair<Integer, Double> predictMax(double[] x) {
        return this.predict(x).entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(e -> new ImmutablePair<>(e.getKey(), e.getValue()))
                .orElse(null);
    }

    /**
     * Returns {@code weight + learningRate * correctionD * x}, computed element-wise over
     * {@code x.length} entries.
     */
    private double[] reweight(double[] x, double[] weight, double correctionD) {
        return IntStream.range(0, x.length).mapToDouble(i -> weight[i] + this.model.learningRate * correctionD * x[i]).toArray();
    }

    /**
     * Applies one perceptron step of size {@code correction} to a single label's weights and bias.
     */
    private void update(double[] x, int index, double correction) {
        this.model.weights[index] = this.reweight(x, this.model.weights[index], correction);
        // The bias acts as the weight of a constant feature 1, so it must accumulate the same
        // step as the weights. Assigning it would discard everything learned so far.
        this.model.bias[index] += this.model.learningRate * correction;
    }

    /**
     * Performs a single online perceptron update for one example.
     *
     * <p>If the highest-scoring label differs from {@code labelIndex}, the correct label is moved
     * towards {@code x} (+1 step) and the wrongly predicted label is moved away from it (-1 step).
     * Otherwise the model is left unchanged.
     *
     * @param x          feature vector of the example
     * @param labelIndex index of the example's gold label
     * @throws IllegalStateException if the model has no labels
     */
    public void onlineTrain(double[] x, int labelIndex) {
        Pair<Integer, Double> best = this.predictMax(x);
        if (best == null) {
            throw new IllegalStateException("Model has no labels; cannot train");
        }
        int predicted = best.getLeft();
        if (predicted != labelIndex) {
            this.update(x, labelIndex, 1.0);
            this.update(x, predicted, -1.0);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("New bias: " + Arrays.toString(this.model.bias));
            // joining() also copes with an empty weight matrix, unlike reduce(...).get().
            LOG.debug("New weight: " + Arrays.stream(this.model.weights).map(Arrays::toString).collect(Collectors.joining(", ")));
        }
    }

    /**
     * Builds a fresh model from {@code trainingData} and trains it until the data is perfectly
     * separated or {@link #MAX_ITERATION} epochs have run.
     *
     * <p>The caller's list is not modified; shuffling happens on a private copy.
     *
     * @param trainingData labelled examples
     * @return this classifier
     */
    @Override
    public IClassifier train(List<Tuple> trainingData) {
        this.model = new PerceptronModel(trainingData);
        List<Tuple> shuffled = new ArrayList<>(trainingData);
        long errCount;
        int iter = 0;
        do {
            LOG.info("Iteration " + ++iter);
            Collections.shuffle(shuffled);
            for (Tuple entry : shuffled) {
                this.onlineTrain(entry.featureVector, this.model.labelIndexer.getIndex(entry.label));
            }
            errCount = this.countErrors(shuffled);
        } while (errCount != 0 && iter < MAX_ITERATION);
        LOG.debug("Err size: " + errCount);
        return this;
    }

    /** Counts the examples whose highest-scoring label differs from their gold label. */
    private long countErrors(List<Tuple> data) {
        return data.stream()
                .filter(t -> this.predictMax(t.featureVector).getLeft() != this.model.labelIndexer.getIndex(t.label))
                .count();
    }

    /**
     * Scores every label for the given example.
     *
     * @param predict example to classify; only its feature vector is read
     * @return map from label to the sigmoid of that label's raw score
     */
    @Override
    public Map<String, Double> predict(Tuple predict) {
        return this.predict(predict.featureVector).entrySet().stream()
                .collect(Collectors.toMap(
                        e -> this.model.labelIndexer.getLabel(e.getKey()),
                        e -> VectorUtils.sigmoid.apply(e.getValue())));
    }

    /** No-op: this classifier currently accepts no external parameters. */
    @Override
    public void setParameter(Map<String, String> paraMap) {
    }

    /**
     * Demo: trains on a small 3-feature dataset and persists the model. It then reloads the
     * model into a fresh classifier and prints the scores for one example.
     */
    public static void main(String[] args) throws FileNotFoundException {
        String modelPath = PerceptronClassifier.class.getResource("/").getPath() + "/perceptron_model.dat";
        System.out.println(modelPath);
        PerceptronClassifier perceptronClassifier = new PerceptronClassifier();
        List<Tuple> data = new ArrayList<>();
        data.add(new Tuple(1, new double[]{1.0, 0.0, 0.0}, String.valueOf(1)));
        data.add(new Tuple(2, new double[]{1.0, 0.0, 1.0}, String.valueOf(1)));
        data.add(new Tuple(3, new double[]{1.0, 1.0, 0.0}, String.valueOf(1)));
        data.add(new Tuple(4, new double[]{1.0, 1.0, 1.0}, String.valueOf(0)));
        perceptronClassifier.train(data);
        perceptronClassifier.model.persist(modelPath);
        perceptronClassifier = new PerceptronClassifier();
        // try-with-resources so the model file handle is always released.
        try (FileInputStream in = new FileInputStream(modelPath)) {
            perceptronClassifier.model.load(in);
        } catch (FileNotFoundException e) {
            throw e;
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read model file " + modelPath, e);
        }
        Tuple test = new Tuple(5, new double[]{1.0, 1.0, 1.0}, null);
        System.out.println(perceptronClassifier.predict(test));
    }
}

