package cn.tyoui.nbc;

import cn.tyoui.dt.DT;
import org.apache.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Naive Bayes classifier over a table of string-valued records.
 *
 * <p>Training data is a two-dimensional array in which every inner array is one record and
 * every position inside a record is one column. One column holds the class label (by default
 * the last one); every other column is a feature. Feature columns are treated as discrete
 * unless they have been registered through {@link #setRowContinual(Integer, Double)}, in which
 * case a normal distribution is fitted to the column.
 *
 * <p>Note: several parameters in this class are named {@code row} or {@code index} for
 * historical reasons; they all denote a <em>column</em> index inside a record.
 *
 * @author Tyoui
 * @version 1.8.1
 */
public class NBC {

    private static final Logger LOGGER = Logger.getLogger(NBC.class);

    /** Label column requested through the constructor, or {@code null} to use the last column. */
    private final Integer requestedLabelColumn;

    /** Column index -> query value for columns that are modelled with a normal distribution. */
    private final Map<Integer, Double> continual = new HashMap<>();

    /** Training records; {@code null} until {@link #readArrayFile(String)} succeeds. */
    private String[][] data;

    /** Effective label column for the currently loaded data; {@code null} until data is loaded. */
    private Integer labelColumn;

    /**
     * Creates a classifier that uses the last column of the data as the class label.
     */
    public NBC() {
        this(null);
    }

    /**
     * Creates a classifier that uses the given column as the class label.
     *
     * @param index column index of the class label; when {@code null} or outside the bounds of
     *              the loaded data, the last column is used instead
     */
    public NBC(Integer index) {
        this.requestedLabelColumn = index;
    }

    /**
     * Loads the training data through {@code DT.getDataFile} and {@code DT.listArrays} and
     * resolves the label column for it.
     *
     * @param file path of the data file
     * @return the loaded records (the same array that is kept internally)
     * @throws IllegalArgumentException if the file yields no records
     * @throws Exception                if reading fails
     */
    public String[][] readArrayFile(String file) throws Exception {
        DT dt = new DT();
        List<String[]> list = dt.getDataFile(file);
        String[][] loaded = dt.listArrays(list);
        if (loaded == null || loaded.length == 0) {
            throw new IllegalArgumentException("No training data found in " + file);
        }
        data = loaded;
        int columns = data[0].length;
        // Honour the constructor argument when it fits this data set; otherwise keep the
        // historical default of "label is the last column".
        boolean requestedUsable = requestedLabelColumn != null
                && requestedLabelColumn >= 0
                && requestedLabelColumn < columns;
        labelColumn = requestedUsable ? requestedLabelColumn : columns - 1;
        return data;
    }

    /**
     * Counts how often each distinct value occurs in one column.
     *
     * @param index column index to count
     * @return map from value to number of records carrying that value
     * @throws IllegalStateException if no data has been loaded
     */
    public Map<String, Integer> eigenvalue(int index) {
        requireData();
        Map<String, Integer> counts = new HashMap<>();
        for (String[] record : data) {
            counts.merge(record[index], 1, Integer::sum);
        }
        return counts;
    }

    /**
     * Computes the distribution of one feature column restricted to the records of one class.
     *
     * <p>For a discrete column the result maps each value seen in that class to
     * {@code occurrences / value}. For a continuous column the result has a single entry whose
     * key is the string form of the configured query value and whose value is the normal
     * density at that query value, using the class mean and (population) variance.
     *
     * @param key   class label
     * @param value divisor to use, normally the number of records carrying {@code key}
     * @param row   column index of the feature
     * @return conditional distribution of the column given the class
     * @throws IllegalStateException if no data has been loaded
     * @throws NullPointerException  if {@code value} is {@code null}
     */
    public Map<String, Double> numerator(String key, Integer value, int row) {
        requireData();
        return columnDistribution(row, true, key, value);
    }

    /**
     * Computes the distribution of one feature column over all records.
     *
     * <p>Same shape as {@link #numerator(String, Integer, int)}, but not restricted to a class
     * and normalised by the total number of records.
     *
     * @param row column index of the feature
     * @return marginal distribution of the column
     * @throws IllegalStateException if no data has been loaded
     */
    public Map<String, Double> denominator(int row) {
        requireData();
        return columnDistribution(row, false, null, data.length);
    }

    /**
     * Evaluates the naive Bayes estimate of {@code P(label | features)} for one query.
     *
     * <p>{@code print} is laid out like a training record: the entry at the label column is the
     * class to score and the other entries are the feature values. For continuous columns the
     * query value registered through {@link #setRowContinual(Integer, Double)} is used and the
     * corresponding entry of {@code print} is ignored.
     *
     * @param print the query record
     * @return the estimated probability, or {@code 0.0} (after logging an error) when the query
     *         is null or too short, the label is unknown, a feature value was never seen, or the
     *         evidence has zero probability
     * @throws IllegalStateException if no data has been loaded
     */
    public double start(String... print) {
        if (print == null) {
            LOGGER.error("输入参数为空");
            return 0.0;
        }
        requireData();
        int columns = data[0].length;
        if (labelColumn < 0 || print.length < columns) {
            LOGGER.error("输入参数个数不足");
            return 0.0;
        }
        String key = print[labelColumn];
        Integer labelCount = eigenvalue(labelColumn).get(key);
        if (labelCount == null) {
            LOGGER.error("特征值找不到。输入有问题!");
            return 0.0;
        }
        double eigenvalueAll = (labelCount * 1.0) / data.length;
        double numeratorAll = 1.0;
        double denominatorAll = 1.0;
        for (int row = 0; row < columns; row++) {
            if (row == labelColumn) {
                continue;
            }
            // Continuous columns store their single entry under the configured query value,
            // so look it up by that key instead of the caller's textual spelling.
            Double query = continual.get(row);
            String lookup = query != null ? String.valueOf(query) : print[row];
            Double likelihood = numerator(key, labelCount, row).get(lookup);
            Double evidence = denominator(row).get(lookup);
            // Check both before unboxing: an unseen feature value must take the error path,
            // not raise a NullPointerException.
            if (likelihood == null || evidence == null) {
                LOGGER.error("前提条件错误!");
                return 0.0;
            }
            numeratorAll *= likelihood;
            denominatorAll *= evidence;
        }
        if (denominatorAll == 0.0) {
            // Evidence with zero probability would otherwise yield NaN or Infinity.
            LOGGER.error("前提条件错误!");
            return 0.0;
        }
        return (numeratorAll * eigenvalueAll) / denominatorAll;
    }

    /**
     * Lists the distinct class labels found in the label column.
     *
     * @return the class labels, in no particular order
     * @throws IllegalStateException if no data has been loaded
     */
    public List<String> eigenvalue() {
        requireData();
        return new ArrayList<>(eigenvalue(labelColumn).keySet());
    }

    /**
     * Marks a column as continuous and sets the value at which its density is evaluated.
     *
     * @param row   column index of the continuous feature
     * @param value query value for that column; a {@code null} value leaves the column treated
     *              as discrete
     */
    public void setRowContinual(Integer row, Double value) {
        continual.put(row, value);
    }

    /**
     * Builds the distribution of one column, optionally restricted to one class.
     *
     * @param row           column index of the feature
     * @param filterByLabel whether only records whose label equals {@code label} are used
     * @param label         class label to filter on (ignored unless {@code filterByLabel})
     * @param total         divisor used for relative frequencies, mean and variance
     * @return the distribution as described on {@link #numerator(String, Integer, int)}
     */
    private Map<String, Double> columnDistribution(int row, boolean filterByLabel,
                                                   String label, double total) {
        Map<String, Double> distribution = new HashMap<>();
        Double query = continual.get(row);
        if (query == null) {
            double weight = 1.0 / total;
            for (String[] record : data) {
                if (!filterByLabel || record[labelColumn].equals(label)) {
                    distribution.merge(record[row], weight, Double::sum);
                }
            }
            return distribution;
        }

        double sum = 0.0;
        for (String[] record : data) {
            if (!filterByLabel || record[labelColumn].equals(label)) {
                sum += Double.parseDouble(record[row]);
            }
        }
        double average = sum / total;

        double squaredDeviation = 0.0;
        for (String[] record : data) {
            if (!filterByLabel || record[labelColumn].equals(label)) {
                squaredDeviation += Math.pow(Double.parseDouble(record[row]) - average, 2);
            }
        }
        double variance = squaredDeviation / total;

        distribution.put(String.valueOf(query), ND(average, variance, query));
        return distribution;
    }

    /**
     * Fails fast when a method that needs training data is called before loading it.
     */
    private void requireData() {
        if (data == null || labelColumn == null) {
            throw new IllegalStateException("No training data loaded; call readArrayFile first");
        }
    }

    /**
     * Evaluates the normal probability density function.
     *
     * <p>When the variance is zero the density is undefined; by convention this method then
     * returns {@code 1.0} if {@code x} equals the mean and {@code 0.0} otherwise, instead of
     * producing {@code NaN}.
     *
     * @param average  mean of the distribution
     * @param variance variance of the distribution
     * @param x        point at which to evaluate the density
     * @return density at {@code x}
     */
    private double ND(double average, double variance, double x) {
        if (variance <= 0.0) {
            return x == average ? 1.0 : 0.0;
        }
        double scale = 1 / Math.sqrt(2 * Math.PI * variance);
        double exponent = Math.pow(average - x, 2) / (2 * variance);
        return scale * Math.exp(-exponent);
    }
}
