/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.classifier.naivebayes;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.stream.Collectors;
import org.apache.commons.lang3.NotImplementedException;
import org.maochen.nlp.classifier.IClassifier;
import org.maochen.nlp.classifier.naivebayes.NBTrainingEngine;
import org.maochen.nlp.classifier.naivebayes.NaiveBayesModel;
import org.maochen.nlp.datastructure.Tuple;
import org.maochen.nlp.utils.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Gaussian Naive Bayes classifier backed by a {@link NaiveBayesModel}.
 *
 * <p>A model is obtained either by training ({@link #train(List)}) or by loading a
 * previously persisted one ({@link #loadModel(String)} or the stream constructor).
 * Prediction combines the per-label prior with one Gaussian density per feature.
 *
 * <p>The class also carries a few static helpers for reading, writing and cleaning
 * training files, plus two ad-hoc command line entry points.
 */
public class NaiveBayesClassifier implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(NaiveBayesClassifier.class);

    /** Normalised probabilities above this value are reported as exactly 1.0. */
    private static final double UPPER_SNAP_THRESHOLD = 0.999;

    /** Normalised probabilities below this value are reported as exactly 0.0. */
    private static final double LOWER_SNAP_THRESHOLD = 0.001;

    /** Current model; {@code null} until a model has been trained or loaded. */
    private NaiveBayesModel model;

    /**
     * Creates a classifier and loads its model from the given stream.
     *
     * @param modelInputStream stream handed to {@link NaiveBayesModel#load}; this constructor
     *                         does not close it itself
     */
    public NaiveBayesClassifier(InputStream modelInputStream) {
        this.model = new NaiveBayesModel();
        this.model.load(modelInputStream);
    }

    /** Creates a classifier without a model; train or load one before predicting. */
    public NaiveBayesClassifier() {
    }

    /**
     * Not supported by this classifier.
     *
     * @param paraMap ignored
     * @throws NotImplementedException always
     */
    @Override
    public void setParameter(Map<String, String> paraMap) {
        throw new NotImplementedException("not implemented");
    }

    /**
     * Computes the normalised posterior probability of every known label for the given tuple.
     *
     * <p>The per-feature densities are accumulated as a sum of logarithms and normalised
     * relative to the largest log posterior. This is mathematically the same as multiplying
     * the densities and dividing by their sum, but it does not underflow to zero (or overflow)
     * when the feature vector is long.
     *
     * <p>Probabilities above {@value #UPPER_SNAP_THRESHOLD} are returned as 1.0 and those
     * below {@value #LOWER_SNAP_THRESHOLD} as 0.0.
     *
     * <p>Side effect: if {@code predict.label} is {@code null} or empty it is set to the most
     * probable label (or to the empty string when there is no label at all).
     *
     * @param predict tuple whose {@code featureVector} is scored
     * @return label name to probability; empty if the model knows no labels
     * @throws IllegalStateException if no model has been trained or loaded
     */
    @Override
    public Map<String, Double> predict(Tuple predict) {
        if (this.model == null) {
            throw new IllegalStateException(
                    "No model available: call train() or loadModel() before predict()");
        }

        // log(prior) + sum(log(density)) per label index.
        Map<Integer, Double> logPosteriors = new HashMap<>();
        double maxLogPosterior = Double.NEGATIVE_INFINITY;
        for (Integer labelIndex : this.model.labelIndexer.getIndexSet()) {
            double logPosterior = Math.log(this.model.labelPrior.get(labelIndex));
            for (int i = 0; i < predict.featureVector.length; ++i) {
                double density = VectorUtils.gaussianPDF(
                        this.model.meanVectors[labelIndex][i],
                        this.model.varianceVectors[labelIndex][i],
                        predict.featureVector[i]);
                logPosterior += Math.log(density);
            }
            logPosteriors.put(labelIndex, logPosterior);
            maxLogPosterior = Math.max(maxLogPosterior, logPosterior);
        }

        if (logPosteriors.isEmpty()) {
            LOG.error("Evidence is Empty!");
            return new HashMap<String, Double>();
        }
        if (!Double.isFinite(maxLogPosterior)) {
            // Every label has zero density, or a density/prior was infinite or NaN. The
            // arithmetic below then yields NaN entries, exactly as the plain product did.
            LOG.warn("Degenerate posterior (max log posterior = {}); probabilities are not well defined",
                    maxLogPosterior);
        }

        // Shift by the maximum so the largest term is exp(0) = 1 and the sum cannot vanish.
        double evidence = 0.0;
        for (double logPosterior : logPosteriors.values()) {
            evidence += Math.exp(logPosterior - maxLogPosterior);
        }

        // Kept as HashMap so the argument type handed to convertMapKey is unchanged.
        HashMap<Integer, Double> labelProb = new HashMap<Integer, Double>();
        for (Map.Entry<Integer, Double> entry : logPosteriors.entrySet()) {
            double prob = Math.exp(entry.getValue() - maxLogPosterior) / evidence;
            labelProb.put(entry.getKey(), snapToBounds(prob));
        }

        Map<String, Double> result = this.model.labelIndexer.convertMapKey(labelProb);
        if (predict.label == null || predict.label.isEmpty()) {
            predict.label = mostProbableLabel(result, "");
        }
        return result;
    }

    /**
     * Returns the most probable label for the given tuple.
     *
     * <p>Delegates to {@link #predict(Tuple)}, so it shares that method's side effect on
     * {@code predict.label}.
     *
     * @param predict tuple to classify
     * @return the label with the highest probability, or {@code null} if there is none
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public String predictLabel(Tuple predict) {
        return mostProbableLabel(this.predict(predict), null);
    }

    /**
     * Trains a fresh model from the given data and installs it in this classifier.
     *
     * @param trainingData labelled tuples passed to {@link NBTrainingEngine}
     * @return this classifier
     */
    @Override
    public IClassifier train(List<Tuple> trainingData) {
        this.model = new NBTrainingEngine(trainingData).train();
        return this;
    }

    /**
     * Persists the current model to the given file. Does nothing (apart from logging a
     * warning) when no model is present.
     *
     * @param filename destination passed to {@link NaiveBayesModel#persist}
     */
    public void persistModel(String filename) {
        if (this.model == null) {
            LOG.warn("No model to persist; nothing written to {}", filename);
            return;
        }
        this.model.persist(filename);
    }

    /**
     * Replaces the current model with one loaded from the given file.
     *
     * <p>The current model is replaced by a new, empty {@link NaiveBayesModel} before the file
     * is opened, so if the file cannot be opened the classifier is left with that empty model.
     * Failures are logged, not thrown.
     *
     * @param filename path of the persisted model
     */
    public void loadModel(String filename) {
        this.model = new NaiveBayesModel();
        // try-with-resources so the file handle is released once the model has been read.
        try (FileInputStream in = new FileInputStream(filename)) {
            this.model.load(in);
        } catch (FileNotFoundException e) {
            LOG.error("Model file not found: {}", filename, e);
        } catch (IOException e) {
            LOG.error("Failed to close model file: {}", filename, e);
        }
    }

    /**
     * Reads labelled training tuples from a text file.
     *
     * <p>Each non-blank line is trimmed; duplicate lines are dropped, keeping the first
     * occurrence. A line is split on {@code delimiter}: the first token is the label and every
     * further token is a feature value, either a plain number or {@code key:value}, in which
     * case the part after the first colon is parsed.
     *
     * <p>An I/O error is logged and whatever was read up to that point is still returned.
     *
     * @param filename  file to read
     * @param delimiter regular expression separating the tokens of a line
     * @return one tuple per distinct non-blank line, in file order
     * @throws NumberFormatException if a feature value is not a valid double
     */
    public static List<Tuple> readTrainingData(String filename, String delimiter) {
        List<String> lines = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(filename))) {
            String line;
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                // A blank line would otherwise become a tuple with an empty label and no features.
                if (!trimmed.isEmpty()) {
                    lines.add(trimmed);
                }
            }
        } catch (IOException e) {
            LOG.error("Failed to read training data from {}", filename, e);
        }

        List<String> distinctLines = lines.stream().distinct().collect(Collectors.toList());
        List<Tuple> data = new ArrayList<>(distinctLines.size());
        for (String line : distinctLines) {
            data.add(parseTrainingLine(line, delimiter));
        }
        return data;
    }

    /**
     * Writes one line per tuple, using {@link Tuple#toString()}, to the given file.
     * An I/O error is logged, not thrown.
     *
     * @param dataset  tuples to write
     * @param filename destination file; overwritten if it exists
     */
    public static void writeToFile(List<Tuple> dataset, String filename) {
        try (BufferedWriter output = new BufferedWriter(new FileWriter(new File(filename)))) {
            for (Tuple t : dataset) {
                output.write(t.toString() + System.lineSeparator());
            }
        } catch (IOException e) {
            LOG.error("Failed to write dataset to {}", filename, e);
        }
    }

    /**
     * Iteratively separates a training file into tuples the classifier agrees with and tuples
     * it does not.
     *
     * <p>Each round trains on the current training set, moves mispredicted tuples to the
     * "wrong" set, then moves tuples of the "wrong" set that are now predicted correctly back.
     * Tuples labelled {@code "1"} are never moved out of the training set. Rounds repeat until
     * the size of the training set is the same at the end of a round as at its start.
     *
     * <p>The results are written next to the input as {@code <file>.aligned} and
     * {@code <file>.wrong}.
     *
     * @param originalTrainingDataFile whitespace-delimited training file
     */
    public static void splitData(String originalTrainingDataFile) {
        List<Tuple> trainingData = NaiveBayesClassifier.readTrainingData(originalTrainingDataFile, "\\s");
        List<Tuple> wrongData = new ArrayList<>();
        int iterCount = 0;
        int lastTrainingDataSize;
        do {
            System.out.println("Iteration:\t" + ++iterCount);
            lastTrainingDataSize = trainingData.size();
            NaiveBayesClassifier nbc = new NaiveBayesClassifier();
            nbc.train(trainingData);

            // Move mispredicted tuples out, except those labelled "1".
            Iterator<Tuple> trainingDataIter = trainingData.iterator();
            while (trainingDataIter.hasNext()) {
                Tuple t = trainingDataIter.next();
                String actual = nbc.predictLabel(t);
                if (t.label.equals(actual) || t.label.equals("1")) {
                    continue;
                }
                wrongData.add(t);
                trainingDataIter.remove();
            }

            // Move tuples the same model now predicts correctly back in.
            Iterator<Tuple> wrongDataIter = wrongData.iterator();
            while (wrongDataIter.hasNext()) {
                Tuple t = wrongDataIter.next();
                String actual = nbc.predictLabel(t);
                if (!t.label.equals(actual)) {
                    continue;
                }
                trainingData.add(t);
                wrongDataIter.remove();
            }
        } while (trainingData.size() != lastTrainingDataSize);

        NaiveBayesClassifier.writeToFile(trainingData, originalTrainingDataFile + ".aligned");
        NaiveBayesClassifier.writeToFile(wrongData, originalTrainingDataFile + ".wrong");
    }

    /**
     * Ad-hoc entry point: trains on a hard-coded file, persists and reloads the model, then
     * classifies whitespace-separated feature vectors typed on standard input until
     * {@code q}, {@code quit} or {@code exit} is entered or input ends.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        String folder = "/Users/Maochen/Desktop/w2v_weight_training/";
        String outputModelFolder = "/Users/Maochen/workspace/amelia/eliza-ir/src/main/resources/";
        NaiveBayesClassifier nbc = new NaiveBayesClassifier();
        List<Tuple> trainingData = NaiveBayesClassifier.readTrainingData(folder + "/training.all.txt.aligned", "\\s");
        nbc.train(trainingData);
        nbc.persistModel(outputModelFolder + "/nb_model.dat");
        nbc.loadModel(outputModelFolder + "/nb_model.dat");

        Scanner scan = new Scanner(System.in);
        String input = "";
        String quitRegex = "q|quit|exit";
        while (!input.matches(quitRegex)) {
            System.out.println("Please enter feats:");
            // Stop cleanly at end of input instead of letting nextLine() throw.
            if (!scan.hasNextLine()) {
                break;
            }
            input = scan.nextLine();
            if (input.trim().isEmpty() || input.matches(quitRegex)) {
                continue;
            }
            double[] feats = Arrays.stream(input.split("\\s")).mapToDouble(Double::parseDouble).toArray();
            Map<String, Double> results = nbc.predict(new Tuple(feats));
            System.out.println(results);
        }
    }

    /**
     * Ad-hoc demo: trains on a small hard-coded male/female dataset and prints the label
     * probabilities of one sample, highest first.
     *
     * @param args unused
     */
    public static void main1(String[] args) {
        NaiveBayesClassifier nbc = new NaiveBayesClassifier();
        List<Tuple> trainingData = new ArrayList<>();
        trainingData.add(new Tuple(1, new double[]{6.0, 180.0, 12.0}, "male"));
        trainingData.add(new Tuple(2, new double[]{5.92, 190.0, 11.0}, "male"));
        trainingData.add(new Tuple(3, new double[]{5.58, 170.0, 12.0}, "male"));
        trainingData.add(new Tuple(4, new double[]{5.92, 165.0, 10.0}, "male"));
        trainingData.add(new Tuple(5, new double[]{5.0, 100.0, 6.0}, "female"));
        trainingData.add(new Tuple(6, new double[]{5.5, 150.0, 8.0}, "female"));
        trainingData.add(new Tuple(7, new double[]{5.42, 130.0, 7.0}, "female"));
        trainingData.add(new Tuple(8, new double[]{5.75, 150.0, 9.0}, "female"));
        Tuple predict = new Tuple(new double[]{6.0, 130.0, 8.0});
        nbc.train(trainingData);
        Map<String, Double> probs = nbc.predict(predict);

        Comparator<Map.Entry<String, Double>> reverseCmp =
                Collections.reverseOrder(Map.Entry.<String, Double>comparingByValue());
        List<Map.Entry<String, Double>> result =
                probs.entrySet().stream().sorted(reverseCmp).collect(Collectors.toList());
        System.out.println("Result: " + predict);
        result.forEach(e -> System.out.println(e.getKey() + "\t:\t" + e.getValue()));
    }

    /** Rounds probabilities that are within the snap thresholds of 0 or 1 to exactly 0.0 or 1.0. */
    private static double snapToBounds(double prob) {
        if (prob > UPPER_SNAP_THRESHOLD) {
            return 1.0;
        }
        if (prob < LOWER_SNAP_THRESHOLD) {
            return 0.0;
        }
        return prob;
    }

    /** Returns the key with the largest value, or {@code fallback} when the map is empty. */
    private static String mostProbableLabel(Map<String, Double> probabilities, String fallback) {
        return probabilities.entrySet().stream()
                .max(Map.Entry.<String, Double>comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(fallback);
    }

    /** Parses one trimmed, non-blank training line ({@code label value...}) into a tuple. */
    private static Tuple parseTrainingLine(String line, String delimiter) {
        String[] tokens = line.trim().split(delimiter);
        String label = tokens[0];
        double[] values = new double[tokens.length - 1];
        for (int i = 1; i < tokens.length; ++i) {
            // "index:value" style tokens carry the number after the colon.
            String tokenString = tokens[i].contains(":") ? tokens[i].split(":")[1] : tokens[i];
            values[i - 1] = Double.parseDouble(tokenString);
        }
        return new Tuple(0, values, label);
    }
}

