package cn.tyoui.dt;

import cn.tyoui.exception.LogException;
import org.apache.log4j.Logger;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Decision tree.
 * <p>
 * Note: this decision tree is mainly intended for discrete data (data that is
 * not normally distributed). The code always treats the LAST column of every
 * row as the class label.
 * <p>
 * Instances are stateful: the current working data set and the index of its
 * class-label column are kept in fields and are replaced by
 * {@link #readerData(String)} and {@link #rewriteData(String[][])}.
 *
 * @author Tyoui
 * @version 1.8.1
 */
public class DT {

    private static final Logger LOGGER = Logger.getLogger(DT.class);

    private final ID3 id3 = new ID3();

    /** Current working data set: one row per sample, one column per attribute. */
    private String[][] data;

    /** Index of the class-label column of {@link #data}; always its last column. */
    private int classIndex;

    /**
     * Reads a tab-separated text resource from the classpath.
     *
     * @param file resource path, resolved through this class's class loader
     * @return one {@code String[]} per line, produced by splitting the line on tab characters
     * @throws Exception if the resource cannot be found or read
     */
    public List<String[]> getDataFile(String file) throws Exception {
        InputStream inputStream = this.getClass().getClassLoader().getResourceAsStream(file);
        if (inputStream == null) {
            // getResourceAsStream signals "not found" with null; fail with a clear message
            // instead of a NullPointerException inside InputStreamReader.
            throw new FileNotFoundException("Resource not found on classpath: " + file);
        }
        List<String[]> list = new ArrayList<>();
        // try-with-resources closes the reader (and the wrapped stream) even when reading fails.
        try (BufferedReader reader =
                     new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            // readLine() returning null is the only reliable end-of-stream signal;
            // ready() merely reports whether a read would block.
            while ((line = reader.readLine()) != null) {
                list.add(line.split("\t"));
            }
        }
        return list;
    }

    /**
     * Counts, for every column, how the samples are distributed.
     * <ul>
     *   <li>For the class-label column: maps each label to a one-element array holding
     *       the number of rows carrying that label.</li>
     *   <li>For every other column: maps each distinct value to an array with one count
     *       per class label that occurs together with that value (array order follows
     *       the iteration order of the internal hash map).</li>
     * </ul>
     * The class-label column is the one currently configured on this instance.
     *
     * @param arrays two-dimensional text data
     * @return column index mapped to the counts described above; empty when
     *         {@code arrays} is {@code null} or has no rows
     */
    public Map<Integer, Map<String, Integer[]>> category(String[][] arrays) {
        Map<Integer, Map<String, Integer[]>> result = new HashMap<>();
        if (arrays == null || arrays.length == 0) {
            return result;
        }
        // Use the first row for the column count so that a single-row data set works too.
        int columns = arrays[0].length;
        for (int col = 0; col < columns; col++) {
            Map<String, Integer[]> counts = new HashMap<>();
            if (col == classIndex) {
                for (String[] row : arrays) {
                    Integer[] total = counts.get(row[col]);
                    if (total == null) {
                        counts.put(row[col], new Integer[]{1});
                    } else {
                        total[0]++;
                    }
                }
            } else {
                // value of this column -> (class label -> number of rows)
                Map<String, Map<String, Integer>> perValue = new HashMap<>();
                for (String[] row : arrays) {
                    perValue.computeIfAbsent(row[col], unused -> new HashMap<>())
                            .merge(row[classIndex], 1, Integer::sum);
                }
                for (Map.Entry<String, Map<String, Integer>> entry : perValue.entrySet()) {
                    counts.put(entry.getKey(), entry.getValue().values().toArray(new Integer[0]));
                }
            }
            result.put(col, counts);
        }
        return result;
    }


    /**
     * Converts the list of rows into a two-dimensional array for easier handling.
     *
     * @param list rows of text
     * @return array whose i-th element is the i-th list element (rows are not copied)
     */
    public String[][] listArrays(List<String[]> list) {
        return list.toArray(new String[0][]);
    }


    /**
     * Aggregates the current data set: computes the value returned by {@code ID3.Gain}
     * for every column except the class-label column.
     *
     * @return column index mapped to the computed gain
     * @throws LogException       on a processing-logic error reported by the ID3 helper
     * @throws IllegalStateException if no data set has been loaded yet
     */
    public Map<Integer, Double> group() throws LogException {
        if (data == null) {
            throw new IllegalStateException("No data loaded; call readerData or rewriteData first");
        }
        // Per-column counts; key is the column index.
        Map<Integer, Map<String, Integer[]>> category = category(data);

        // Number of rows per class label in the data set D.
        Map<String, Integer[]> classCounts = category.get(classIndex);

        // Expected information of the data set D.
        Integer[] classTotals = new Integer[classCounts.size()];
        int next = 0;
        for (Integer[] total : classCounts.values()) {
            classTotals[next++] = total[0];
        }
        double info = id3.Info_D(classTotals);

        // Gain of every remaining column.
        Map<Integer, Double> gains = new HashMap<>();
        for (Map.Entry<Integer, Map<String, Integer[]>> entry : category.entrySet()) {
            int column = entry.getKey();
            if (column == classIndex) {
                continue;
            }
            double entropy = 0.0;
            for (Integer[] values : entry.getValue().values()) {
                entropy += id3.Info_A(data.length, values);
            }
            entropy = id3.Gain(info, entropy);
            gains.put(column, entropy);
        }
        return gains;
    }

    /**
     * Loads the data set from a classpath resource and makes it the current data set.
     * Failures are logged and leave the previously loaded data untouched.
     *
     * @param file resource path
     */
    public void readerData(String file) {
        try {
            rewriteData(listArrays(getDataFile(file)));
        } catch (Exception e) {
            // Keep the original best-effort contract (no rethrow) but do not lose the cause.
            LOGGER.error("读取文件异常", e);
        }
    }

    /**
     * Finds the column with the largest value in the given map.
     *
     * @param map column index mapped to its gain
     * @return the key with the largest value; {@code 0} when the map is empty
     */
    public Integer maxEntropy(Map<Integer, Double> map) {
        Integer best = 0;
        // Start below every possible value so the result is always a key of the map
        // (a 0.0 seed would return column 0 whenever no gain is strictly positive).
        double max = Double.NEGATIVE_INFINITY;
        for (Map.Entry<Integer, Double> entry : map.entrySet()) {
            double value = entry.getValue();
            if (value > max) {
                max = value;
                best = entry.getKey();
            }
        }
        return best;
    }


    /**
     * Counts how often each distinct value occurs in one column of the current data set.
     *
     * @param head column index
     * @return distinct value mapped to its number of occurrences
     */
    public Map<String, Integer> treeNode(Integer head) {
        int column = head;
        Map<String, Integer> values = new HashMap<>();
        for (String[] row : data) {
            values.merge(row[column], 1, Integer::sum);
        }
        return values;
    }


    /**
     * Partitions the current data set by the values of one column. Each partition
     * contains the matching rows, in their original order, with that column removed.
     *
     * @param head column to split on
     * @return distinct column value mapped to its partition of the data
     */
    public Map<String, String[][]> createTree(Integer head) {
        int column = head;
        Map<String, Integer> sizes = treeNode(head);
        Map<String, String[][]> partitions = new HashMap<>();
        // Next free row per partition. Tracking it explicitly avoids probing a neighbouring
        // cell for null, which went out of bounds when only one column remained.
        Map<String, Integer> nextRow = new HashMap<>();
        for (String[] row : data) {
            String key = row[column];
            String[][] subset = partitions.get(key);
            if (subset == null) {
                subset = new String[sizes.get(key)][];
                partitions.put(key, subset);
            }
            String[] reduced = new String[row.length - 1];
            for (int src = 0, dst = 0; src < row.length; src++) {
                if (src != column) {
                    reduced[dst++] = row[src];
                }
            }
            int target = nextRow.merge(key, 1, Integer::sum) - 1;
            subset[target] = reduced;
        }
        return partitions;
    }

    /**
     * Prints the data to standard output, one row per line.
     *
     * @param data data to print
     */
    public void printData(String[][] data) {
        for (String[] row : data) {
            for (String value : row) {
                System.out.print("\t" + value + "\t");
            }
            System.out.println();
        }
    }

    /**
     * Replaces the current data set; the last column of the first row determines
     * the class-label column.
     *
     * @param data new data set, must contain at least one row
     */
    public void rewriteData(String[][] data) {
        // Compute the index first so that invalid input fails before any state changes.
        int lastColumn = data[0].length - 1;
        this.data = data;
        this.classIndex = lastColumn;
    }


    /**
     * Builds the tree by repeatedly splitting on the column with the highest gain,
     * visiting partitions depth-first (pre-order). Every partition produced at any
     * depth is appended as a child of the given {@code tree} node.
     * <p>
     * A partition is not split further when it holds a single row or when no
     * attribute column is left besides the class label. The data set that was
     * current when the method was called is restored before it returns.
     *
     * @param tree node that collects the generated partitions
     * @throws LogException if the gain computation fails
     */
    public void loopTree(Tree tree) throws LogException {
        String[][] savedData = this.data;
        int savedIndex = this.classIndex;
        try {
            Integer head = maxEntropy(group());
            Map<String, String[][]> partitions = createTree(head);
            for (Map.Entry<String, String[][]> entry : partitions.entrySet()) {
                String[][] subset = entry.getValue();
                Tree tr = new Tree();
                tr.attribute = entry.getKey();
                tr.data = subset;
                tree.children.add(tr);
                // Leaf: a single sample, or only the class-label column remains.
                if (subset.length == 1 || subset[0].length <= 1) {
                    continue;
                }
                rewriteData(subset);
                loopTree(tree);
            }
        } finally {
            // Recursion swaps the working data set; put the caller's one back.
            this.data = savedData;
            this.classIndex = savedIndex;
        }
    }

    /**
     * Appends the data of every direct child of the tree to a text file.
     * The file is written as UTF-8, the same charset used for reading input.
     * I/O failures are logged and do not propagate.
     *
     * @param tree     tree whose children are written
     * @param filePath path of the output file (opened in append mode)
     */
    public void writeTree(Tree tree, String filePath) {
        File file = new File(filePath);
        Writer writer;
        try {
            writer = new BufferedWriter(
                    new OutputStreamWriter(new FileOutputStream(file, true), StandardCharsets.UTF_8));
        } catch (IOException e) {
            LOGGER.error("文件名有问题或者改文件不可打开", e);
            return;
        }
        try {
            for (Tree trees : tree.children) {
                StringBuilder text = new StringBuilder();
                String[][] data = trees.data;
                text.append("属性:").append(trees.attribute).append("\n");
                for (String[] row : data) {
                    for (String value : row) {
                        text.append("\t").append(value).append("\t");
                    }
                    text.append("\n");
                }
                try {
                    writer.write(text.toString());
                } catch (IOException e) {
                    LOGGER.error("文件写入异常", e);
                }
            }
        } finally {
            // Close in finally so the file handle is released even on a runtime exception.
            try {
                writer.close();
            } catch (IOException e) {
                LOGGER.error("文件关闭失败", e);
            }
        }
    }
}
