/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.randomForest;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.dromara.randomForest.DataTable;
import org.dromara.randomForest.Node;
import org.dromara.randomForest.TreeWithTrust;

/**
 * A single decision tree built from an integer-coded {@link DataTable}.
 *
 * <p>The column named by {@link DataTable#getKey()} holds the class label of each row; every other
 * column is a candidate split attribute. At each node the tree considers only attributes whose
 * information gain is at least the average gain (within 1e-6). Among those it picks the one with
 * the highest gain ratio. Once fully grown, the tree is pruned bottom-up: a node collapses into a
 * leaf whenever its children do not classify the training rows under it more accurately than the
 * node's own majority label does.
 *
 * <p>No synchronization is performed; {@link #study()} rewrites the tree's fields in place.
 */
public class Tree {
    private DataTable dataTable;
    /** Attribute name -> integer-coded column, indexed by training row. */
    private Map<String, List<Integer>> table;
    private Node rootNode;
    /** Label column: the class of each training row, indexed like the columns of {@link #table}. */
    private List<Integer> endList;
    /** Leaves collected while growing the tree; each one's parent is a pruning starting point. */
    private final List<Node> lastNodes = new ArrayList<Node>();
    /** Used to pick an arbitrary branch when an observed attribute value has no matching child. */
    private final Random random = new Random();
    /** Multiplier applied to the trust score each time {@link #judge} is penalised. */
    private final double trustPunishment;

    public Node getRootNode() {
        return this.rootNode;
    }

    public DataTable getDataTable() {
        return this.dataTable;
    }

    public void setRootNode(Node rootNode) {
        this.rootNode = rootNode;
    }

    /**
     * Creates a tree with no training data, e.g. to attach a previously built root via
     * {@link #setRootNode(Node)}.
     *
     * @param trustPunishment factor the trust score is multiplied by on each penalty
     */
    public Tree(double trustPunishment) {
        this.trustPunishment = trustPunishment;
    }

    /**
     * Creates a tree that will be trained on {@code dataTable} by {@link #study()}.
     *
     * @param dataTable training data; must be non-null and have a label key
     * @param trustPunishment factor the trust score is multiplied by on each penalty
     * @throws Exception if {@code dataTable} or its key is null
     */
    public Tree(DataTable dataTable, double trustPunishment) throws Exception {
        if (dataTable == null || dataTable.getKey() == null) {
            throw new Exception("dataTable is empty");
        }
        this.trustPunishment = trustPunishment;
        this.dataTable = dataTable;
    }

    /** Base-2 logarithm. */
    private double log2(double p) {
        return Math.log(p) / Math.log(2.0);
    }

    /**
     * Counts how many of the given rows carry each class label.
     *
     * @param list row indices into {@link #endList}
     * @return label -> number of rows with that label
     */
    private Map<Integer, Integer> countTypes(List<Integer> list) {
        Map<Integer, Integer> counts = new HashMap<Integer, Integer>();
        for (int index : list) {
            int type = this.endList.get(index);
            Integer seen = counts.get(type);
            counts.put(type, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    /**
     * Shannon entropy (in bits) of the class labels of the given rows: {@code -sum(p * log2(p))}.
     * Returns 0 for an empty list.
     */
    private double getEnt(List<Integer> list) {
        double ent = 0.0;
        for (int count : this.countTypes(list).values()) {
            double g = (double) count / (double) list.size();
            ent += g * this.log2(g);
        }
        return -ent;
    }

    /** Adds one branch's weighted entropy ({@code ent * dNub}) to the running total {@code gain}. */
    private double getGain(double ent, double dNub, double gain) {
        return gain + ent * dNub;
    }

    /**
     * Recursively grows the subtree below {@code node}.
     *
     * <p>When no attribute is left (or there are no rows), {@code node} becomes a leaf labelled
     * with its majority class and is recorded for pruning. Otherwise the best split attribute is
     * stored in {@code node.key}. One child is created per observed value of that attribute, and
     * each child is grown in turn.
     *
     * @param node node whose {@code attribute} and {@code fatherList} are already set
     * @return the children of {@code node}, or {@code null} if it became a leaf
     */
    private List<Node> createNode(Node node) {
        Set<String> attributes = node.attribute;
        List<Integer> fatherList = node.fatherList;
        if (attributes.isEmpty() || fatherList.isEmpty()) {
            node.isEnd = true;
            node.type = this.getType(fatherList);
            this.lastNodes.add(node);
            return null;
        }
        Map<String, Map<Integer, List<Integer>>> mapAll = this.partition(fatherList, attributes);
        double fatherEnt = this.getEnt(fatherList);
        int fatherNub = fatherList.size();

        Map<String, Gain> gainMap = new HashMap<String, Gain>();
        double sigmaG = 0.0;
        for (Map.Entry<String, Map<Integer, List<Integer>>> attrEntry : mapAll.entrySet()) {
            Gain gain = this.computeGain(attrEntry.getValue(), fatherEnt, fatherNub);
            gainMap.put(attrEntry.getKey(), gain);
            sigmaG = gain.gain + sigmaG;
        }
        double avgGain = sigmaG / (double) gainMap.size();
        String key = this.chooseSplit(gainMap, avgGain);
        node.key = key;

        // Only the chosen attribute's branches become real children.
        List<Node> nodeList = new ArrayList<Node>();
        for (Map.Entry<Integer, List<Integer>> branch : mapAll.get(key).entrySet()) {
            Node sonNode = new Node();
            sonNode.attribute = this.removeAttribute(attributes, key);
            sonNode.fatherList = branch.getValue();
            sonNode.typeId = branch.getKey();
            sonNode.fatherNode = node;
            nodeList.add(sonNode);
        }
        for (Node sonNode : nodeList) {
            sonNode.nodeList = this.createNode(sonNode);
        }
        return nodeList;
    }

    /**
     * Groups rows by value for every candidate attribute.
     *
     * @return attribute -> (attribute value -> row indices having that value)
     */
    private Map<String, Map<Integer, List<Integer>>> partition(List<Integer> rows, Set<String> attributes) {
        Map<String, Map<Integer, List<Integer>>> mapAll = new HashMap<String, Map<Integer, List<Integer>>>();
        for (int index : rows) {
            for (String attr : attributes) {
                Map<Integer, List<Integer>> byValue = mapAll.get(attr);
                if (byValue == null) {
                    byValue = new HashMap<Integer, List<Integer>>();
                    mapAll.put(attr, byValue);
                }
                int attrValue = this.table.get(attr).get(index);
                List<Integer> bucket = byValue.get(attrValue);
                if (bucket == null) {
                    bucket = new ArrayList<Integer>();
                    byValue.put(attrValue, bucket);
                }
                bucket.add(index);
            }
        }
        return mapAll;
    }

    /**
     * Computes the information gain and gain ratio of splitting into {@code branches}.
     *
     * <p>When the split information is zero (a single branch), the gain ratio is set to the
     * sentinel {@code 1000000.0}.
     */
    private Gain computeGain(Map<Integer, List<Integer>> branches, double fatherEnt, int fatherNub) {
        double weightedEnt = 0.0;
        // Accumulates sum(d * log2(d)); the split information is its negation.
        double negSplitInfo = 0.0;
        for (List<Integer> branch : branches.values()) {
            double dNub = (double) branch.size() / (double) fatherNub;
            negSplitInfo = dNub * this.log2(dNub) + negSplitInfo;
            weightedEnt = this.getGain(this.getEnt(branch), dNub, weightedEnt);
        }
        Gain result = new Gain();
        result.gain = fatherEnt - weightedEnt;
        result.gainRatio = negSplitInfo != 0.0 ? result.gain / -negSplitInfo : 1000000.0;
        return result;
    }

    /**
     * Picks the split attribute: among attributes whose gain is at least {@code avgGain}
     * (within 1e-6), the one with the highest gain ratio. Later entries win ties. If there is only
     * one candidate, it is chosen unconditionally.
     *
     * @return the chosen attribute name, or {@code null} if {@code gainMap} is empty
     */
    private String chooseSplit(Map<String, Gain> gainMap, double avgGain) {
        // -2.0 marks "nothing chosen yet"; real gain ratios are never this value.
        double bestRatio = -2.0;
        String key = null;
        for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
            Gain gain = entry.getValue();
            if (gainMap.size() != 1) {
                boolean aboveAverage = gain.gain >= avgGain || Math.abs(gain.gain - avgGain) < 1.0E-6;
                boolean notWorse = gain.gainRatio >= bestRatio || bestRatio == -2.0;
                if (!aboveAverage || !notWorse) {
                    continue;
                }
            }
            bestRatio = gain.gainRatio;
            key = entry.getKey();
        }
        return key;
    }

    /**
     * Returns the most frequent class label among the given rows. The first label seen in map
     * iteration order wins ties; 0 is returned for an empty list.
     */
    private int getType(List<Integer> list) {
        int type = 0;
        int nub = 0;
        for (Map.Entry<Integer, Integer> entry : this.countTypes(list).entrySet()) {
            if (entry.getValue() > nub) {
                type = entry.getKey();
                nub = entry.getValue();
            }
        }
        return type;
    }

    /** Returns a new set holding every element of {@code attributes} except {@code name}. */
    private Set<String> removeAttribute(Set<String> attributes, String name) {
        HashSet<String> attriBute = new HashSet<String>();
        for (String myName : attributes) {
            if (!myName.equals(name)) {
                attriBute.add(myName);
            }
        }
        return attriBute;
    }

    /**
     * Reads attribute {@code name} from {@code ob} by reflectively calling its public no-arg getter
     * {@code get<Name>()} and parsing the result's {@code toString()} as an int.
     *
     * @throws IllegalArgumentException if the getter returns {@code null}
     * @throws Exception on reflection failures or a non-integer value
     */
    private int getTypeId(Object ob, String name) throws Exception {
        // Locale.ROOT: under e.g. a Turkish default locale, "id" would otherwise become "getİd".
        String methodName = "get" + name.substring(0, 1).toUpperCase(Locale.ROOT) + name.substring(1);
        Method method = ob.getClass().getMethod(methodName);
        Object value = method.invoke(ob);
        if (value == null) {
            throw new IllegalArgumentException("getter " + methodName + " returned null");
        }
        return Integer.parseInt(value.toString());
    }

    /**
     * Classifies {@code ob} by walking the tree from the root.
     *
     * <p>Attribute values are read through getters (see {@link #getTypeId}). Trust starts at 1.0
     * and is multiplied by {@code trustPunishment}:
     * <ul>
     *   <li>once for each visited attribute whose value is 0;</li>
     *   <li>once whenever no child matches the value, in which case a random child is taken;</li>
     *   <li>on reaching a leaf whose {@code typeId} is 0, once per root attribute beyond the
     *       depth reached.</li>
     * </ul>
     *
     * @return the predicted label and its trust score
     * @throws Exception if the tree has not been built, or an attribute cannot be read
     */
    public TreeWithTrust judge(Object ob) throws Exception {
        if (this.rootNode == null) {
            throw new Exception("rootNode is null");
        }
        TreeWithTrust treeWithTrust = new TreeWithTrust();
        treeWithTrust.setTrust(1.0);
        this.goTree(ob, this.rootNode, treeWithTrust, 0);
        return treeWithTrust;
    }

    /** Multiplies the current trust score by {@link #trustPunishment}. */
    private void punishment(TreeWithTrust treeWithTrust) {
        treeWithTrust.setTrust(treeWithTrust.getTrust() * this.trustPunishment);
    }

    /**
     * Descends from {@code node} to a leaf, applying trust penalties along the way, and stores the
     * leaf's label in {@code treeWithTrust}.
     *
     * @param times number of edges already traversed before {@code node}
     */
    private void goTree(Object ob, Node node, TreeWithTrust treeWithTrust, int times) throws Exception {
        while (!node.isEnd) {
            int myType = this.getTypeId(ob, node.key);
            if (myType == 0) {
                this.punishment(treeWithTrust);
            }
            List<Node> nodeList = node.nodeList;
            Node next = null;
            for (Node testNode : nodeList) {
                if (testNode.typeId == myType) {
                    next = testNode;
                    break;
                }
            }
            if (next == null) {
                // Unseen attribute value: guess a branch and lower the trust.
                this.punishment(treeWithTrust);
                next = nodeList.get(this.random.nextInt(nodeList.size()));
            }
            node = next;
            ++times;
        }
        if (node.typeId == 0) {
            int nub = this.rootNode.attribute.size() - times;
            for (int i = 0; i < nub; ++i) {
                this.punishment(treeWithTrust);
            }
        }
        treeWithTrust.setType(node.type);
    }

    /**
     * Builds and prunes the tree from {@link #getDataTable()}.
     *
     * @throws Exception if there is no data table or it has no rows
     */
    public void study() throws Exception {
        if (this.dataTable == null || this.dataTable.getLength() <= 0) {
            throw new Exception("dataTable is null");
        }
        // Start from a clean slate even if a previous study() failed part-way.
        this.lastNodes.clear();
        try {
            this.rootNode = new Node();
            this.table = this.dataTable.getTable();
            this.endList = this.table.get(this.dataTable.getKey());
            Set<String> set = this.dataTable.getKeyType();
            // NOTE(review): if getKeyType() returns the DataTable's live set, this permanently
            // removes the label column from it; preserved as-is since callers may rely on it.
            set.remove(this.dataTable.getKey());
            this.rootNode.attribute = set;
            List<Integer> list = new ArrayList<Integer>(this.endList.size());
            for (int i = 0; i < this.endList.size(); ++i) {
                list.add(i);
            }
            this.rootNode.fatherList = list;
            this.rootNode.nodeList = this.createNode(this.rootNode);
            for (Node lastNode : this.lastNodes) {
                this.prune(lastNode.fatherNode);
            }
        } finally {
            this.lastNodes.clear();
        }
    }

    /**
     * Walks up from {@code node}, collapsing each ancestor into a leaf for as long as
     * {@link #isPrune} says its children do not improve training accuracy.
     */
    private void prune(Node node) {
        while (node != null && !node.isEnd && this.isPrune(node, node.nodeList)) {
            this.deduction(node);
            node = node.fatherNode;
        }
    }

    /** Turns {@code node} into a leaf labelled with the majority class of its rows. */
    private void deduction(Node node) {
        node.isEnd = true;
        node.nodeList = null;
        node.type = this.getType(node.fatherList);
    }

    /**
     * Decides whether {@code father} should be collapsed. It should be collapsed when its children,
     * each predicting its own majority label, are correct on no larger a fraction of rows than
     * {@code father}'s single majority label.
     */
    private boolean isPrune(Node father, List<Node> sonNodes) {
        int fatherType = this.getType(father.fatherList);
        double rightFather = (double) this.getRightPoint(father.fatherList, fatherType)
                / (double) father.fatherList.size();
        int rightNub = 0;
        int rightAllNub = 0;
        for (Node son : sonNodes) {
            List<Integer> rows = son.fatherList;
            rightNub += this.getRightPoint(rows, this.getType(rows));
            rightAllNub += rows.size();
        }
        double rightPoint = (double) rightNub / (double) rightAllNub;
        return rightPoint <= rightFather;
    }

    /** Counts the rows in {@code types} whose label equals {@code type}. */
    private int getRightPoint(List<Integer> types, int type) {
        int nub = 0;
        for (int index : types) {
            if (this.endList.get(index) == type) {
                ++nub;
            }
        }
        return nub;
    }

    /** Information gain and gain ratio of one candidate split attribute. */
    private static class Gain {
        private double gain;
        private double gainRatio;

        private Gain() {
        }
    }
}

