package cn.tyoui.dt;

import cn.tyoui.exception.LogException;
import org.apache.log4j.Logger;

public class ID3 {

    private Logger logger = Logger.getLogger(ID3.class);

    /**
     * 如果a的x次方等于N（a大于0，且a不等于1）那么数x叫做以a为底N的对数:记作x=logaN。
     * logaN=logN/loga
     *
     * @param a 对数的底数
     * @param N 对数的真数
     * @return 返回结果
     * @throws LogException 逻辑错误
     */
    public double log(double a, double N) throws LogException {
        if (a <= 0 || a == 1) {
            logger.error("底数不能小于等于0或等于1");
            throw new LogException("底数不能小于等于0或等于1");
        }
        return Math.log(N) / Math.log(a);
    }

    /**
     * 同上
     *
     * @param N 真数
     * @return 返回结果
     * @throws LogException 逻辑错误
     */
    private double log(double N) throws LogException {
        return log(2.0, N);
    }


    /**
     * 返回信息熵的期望信息
     *
     * @param N D中任意元组的非零概率
     * @return 信息熵
     * @throws LogException log异常
     */
    private double Info_D(double N) throws LogException {
        if (N == 0 || N == 1)
            return 0.0;
        return -N * log(N);
    }

    /**
     * 基尼不纯度
     *
     * @param info_D 信息熵的期望值
     * @param info_A 分类所需要的期望信息
     * @return 信息增益值
     */
    public double Gain(double info_D, double info_A) {
        return info_D - info_A;
    }

    /**
     * 计算对训练集D进行分类所需要的期望信息
     *
     * @param m 分类的个数
     * @return 训练集D的期望
     * @throws LogException log异常
     */
    public double Info_D(Integer... m) throws LogException {
        double all = 0.0;
        for (int num : m)
            all += num;
        double d = 0.0;
        for (int num : m) {
            d += Info_D(num / all);
        }
        return d;
    }


    /**
     * 计算不同属性的期望信息需求
     *
     * @param all 分类总和
     * @param m   各分类的数量
     * @return 数学期望
     * @throws LogException log异常
     */
    public double Info_A(int all, Integer... m) throws LogException {
        double all_m = 0.0;
        for (int num : m)
            all_m += num;
        double d = 0.0;
        for (int num : m) {
            d += Info_D(num / all_m);
        }
        return d * (all_m / all);
    }
}
