C4.5算法建立决策树JAVA实现

转载链接:http://www.cnblogs.com/lixusign/archive/2012/06/13/2548124.html

当前的属性为:age income student credit_rating

当前的数据集为(最后一列是TARGET_VALUE):

---------------------------------

youth     high   no   fair      no 
youth     high   no   excellent   no 
middle_aged   high   no   fair     yes 
senior     low    yes  fair     yes 
senior     low    yes  excellent   no 
middle_aged   low    yes  excellent   yes 
youth     medium  no   fair     no 
youth     low     yes  fair     yes 
senior     medium  yes    fair     yes 
youth     medium  yes    excellent   yes 
middle_aged   high   yes  fair        yes 
senior     medium  no     excellent   no 
---------------------------------

C4.5建立树类

复制代码
package C45Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Builds a decision tree by recursively splitting the data on the attribute
 * with the largest gain ratio, as computed by {@code InfoGain}.
 */
public class DecisionTree {

    /**
     * Recursively builds the (sub)tree for the given rows.
     *
     * <p>Each call prints the current rows and attribute names to standard output.
     * A leaf (node name {@code "leafNode"}) is returned when all rows share one
     * target value, when no attributes are left, or when no remaining attribute
     * has a positive gain ratio; in the latter two cases the leaf carries the most
     * frequent target value. Otherwise the node stores the chosen attribute name
     * and gets one child per distinct value of that attribute, with the value
     * recorded at the same position of {@code getPathName()}.
     *
     * @param data rows of values; the last element of every row is the target value
     * @param attributeList attribute names, one per non-target column, in column order
     * @return the root of the tree built for {@code data}
     * @throws IllegalArgumentException if {@code data} is null or empty
     */
    public TreeNode createDT(List<ArrayList<String>> data,List<String> attributeList){

        if(data == null || data.isEmpty()){
            throw new IllegalArgumentException("data must contain at least one row");
        }
        printState(data, attributeList);

        String result = InfoGain.IsPure(InfoGain.getTarget(data));
        if(result != null){
            return newLeaf(result);
        }
        if(attributeList.size() == 0){
            // Mixed targets but nothing left to split on: fall back to the majority class.
            return newLeaf(majorityTarget(data));
        }

        InfoGain gain = new InfoGain(data,attributeList);
        int attrIndex = selectAttribute(gain, attributeList.size());
        if(attrIndex < 0){
            // No attribute separates the rows any further.
            return newLeaf(majorityTarget(data));
        }

        String attrName = attributeList.get(attrIndex);
        System.out.println("选择出的最大增益率属性为: " + attrName);
        TreeNode node = new TreeNode();
        node.setAttributeValue(attrName);

        Map<String,Long> attrvalueMap = gain.getAttributeValue(attrIndex);
        for(Map.Entry<String, Long> entry : attrvalueMap.entrySet()){
            List<ArrayList<String>> resultData = gain.getData4Value(entry.getKey(), attrIndex);
            System.out.println("当前为"+attrName+"的"+entry.getKey()+"分支。");
            TreeNode child;
            if(resultData.size() == 0){
                child = new TreeNode();
                child.setNodeName(attrName);
                child.setTargetFunValue(majorityTarget(data));
                child.setAttributeValue(entry.getKey());
            }else{
                // Drop the column that was just used so indexes keep matching resultAttr.
                for(ArrayList<String> row : resultData){
                    row.remove(attrIndex);
                }
                ArrayList<String> resultAttr = new ArrayList<String>(attributeList);
                resultAttr.remove(attrIndex);
                child = createDT(resultData,resultAttr);
            }
            node.getChildTreeNode().add(child);
            node.getPathName().add(entry.getKey());
        }
        return node;
    }

    /** Prints the current rows and attribute names to standard output. */
    private static void printState(List<ArrayList<String>> data,List<String> attributeList){

        System.out.println("当前的DATA为");
        for(ArrayList<String> row : data){
            for(String cell : row){
                System.out.print(cell+ " ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
        System.out.println("当前的ATTR为");
        for(String name : attributeList){
            System.out.print(name+ " ");
        }
        System.out.println();
        System.out.println("---------------------------------");
    }

    /**
     * Returns the index of the attribute with the largest positive gain ratio,
     * or -1 when there is none.
     */
    private static int selectAttribute(InfoGain gain,int attributeCount){

        double maxGain = 0.0;
        int best = -1;
        for(int i=0;i<attributeCount;i++){
            // Zero split info means the attribute has a single value in these rows:
            // it cannot split them, and the ratio would be a division by zero.
            if(gain.getSplitInfo(i) <= 0.0){
                continue;
            }
            double tempGain = gain.getGainRatio(i);
            if(maxGain < tempGain){
                maxGain = tempGain;
                best = i;
            }
        }
        return best;
    }

    /** Returns the most frequent target value; ties go to the value seen first. */
    private static String majorityTarget(List<ArrayList<String>> data){

        List<String> labels = new ArrayList<String>();
        List<Integer> counts = new ArrayList<Integer>();
        for(String target : InfoGain.getTarget(data)){
            int pos = labels.indexOf(target);
            if(pos < 0){
                labels.add(target);
                counts.add(1);
            }else{
                counts.set(pos, counts.get(pos) + 1);
            }
        }
        String best = null;
        int bestCount = 0;
        for(int i=0;i<labels.size();i++){
            if(counts.get(i) > bestCount){
                bestCount = counts.get(i);
                best = labels.get(i);
            }
        }
        return best;
    }

    /** Creates a node named {@code "leafNode"} carrying the given target value. */
    private TreeNode newLeaf(String targetValue){

        TreeNode leaf = new TreeNode();
        leaf.setNodeName("leafNode");
        leaf.setTargetFunValue(targetValue);
        return leaf;
    }

    /**
     * A tree node. {@code childTreeNode} and {@code pathName} are filled in
     * parallel by {@code createDT}: the i-th path name is the attribute value
     * that leads to the i-th child.
     */
    class TreeNode{

        private String attributeValue;
        private List<TreeNode> childTreeNode;
        private List<String> pathName;
        private String targetFunValue;
        private String nodeName;

        /** Creates a node with the given name and empty child and path lists. */
        public TreeNode(String nodeName){

            this.nodeName = nodeName;
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        /** Creates an unnamed node with empty child and path lists. */
        public TreeNode(){
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        public String getAttributeValue() {
            return attributeValue;
        }

        public void setAttributeValue(String attributeValue) {
            this.attributeValue = attributeValue;
        }

        public List<TreeNode> getChildTreeNode() {
            return childTreeNode;
        }

        public void setChildTreeNode(List<TreeNode> childTreeNode) {
            this.childTreeNode = childTreeNode;
        }

        public String getTargetFunValue() {
            return targetFunValue;
        }

        public void setTargetFunValue(String targetFunValue) {
            this.targetFunValue = targetFunValue;
        }

        public String getNodeName() {
            return nodeName;
        }

        public void setNodeName(String nodeName) {
            this.nodeName = nodeName;
        }

        public List<String> getPathName() {
            return pathName;
        }

        public void setPathName(List<String> pathName) {
            this.pathName = pathName;
        }

    }
}
复制代码

 

 

增益率计算类(取log的时候底用的是e,没用2)

复制代码
package C45Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

// C4.5 entropy, information gain and gain ratio over a table of string values.
// All logarithms are natural logarithms (Math.log).
public class InfoGain {

    // Private deep copy of the rows; the last element of each row is the target value.
    private List<ArrayList<String>> data;
    // Private copy of the attribute names.
    private List<String> attribute;

    /**
     * Copies the rows and attribute names so that later changes to the
     * arguments do not affect this instance.
     *
     * @param data rows of values; the last element of every row is the target value
     * @param attribute attribute names
     */
    public InfoGain(List<ArrayList<String>> data,List<String> attribute){

        this.data = new ArrayList<ArrayList<String>>(data.size());
        for(List<String> row : data){
            this.data.add(new ArrayList<String>(row));
        }
        this.attribute = new ArrayList<String>(attribute);
    }

    /** Returns the entropy of the target column (0.0 when there are no rows). */
    public double getEntropy(){

        return entropyOf(getTargetValue(), data.size());
    }

    /**
     * Returns the expected entropy of the target column after partitioning the
     * rows by the values of the given column.
     *
     * @param attributeIndex column index of the attribute
     */
    public double getInfoAttribute(int attributeIndex){

        int size = data.size();
        double infoA = 0.0;
        for(Map.Entry<String, Long> entry : getAttributeValue(attributeIndex).entrySet()){
            double attributeP = entry.getValue().doubleValue() / size;
            Map<String,Long> targetValueMap = getAttributeValueTargetValue(entry.getKey(),attributeIndex);
            long totalCount = 0L;
            for(Long count : targetValueMap.values()){
                totalCount += count.longValue();
            }
            infoA += attributeP * entropyOf(targetValueMap, totalCount);
        }
        return infoA;
    }

    /**
     * Counts the target values of the rows whose column {@code attributeIndex}
     * equals {@code attributeName}. Matching is exact (case-sensitive), the same
     * way {@link #getAttributeValue(int)} groups values.
     *
     * @param attributeName the attribute value to select rows by
     * @param attributeIndex column index of the attribute
     * @return target value to number of matching rows
     */
    public Map<String,Long> getAttributeValueTargetValue(String attributeName,int attributeIndex){

        Map<String,Long> targetValueMap = new HashMap<String,Long>();
        for(List<String> row : data){
            if(sameValue(attributeName, row.get(attributeIndex))){
                increment(targetValueMap, row.get(row.size() - 1));
            }
        }
        return targetValueMap;
    }

    /**
     * Counts how many rows carry each distinct value of the given column.
     *
     * @param attributeIndex column index of the attribute
     * @return attribute value to number of rows
     */
    public Map<String,Long> getAttributeValue(int attributeIndex){

        Map<String,Long> attributeValueMap = new HashMap<String,Long>();
        for(ArrayList<String> row : data){
            increment(attributeValueMap, row.get(attributeIndex));
        }
        return attributeValueMap;
    }

    /**
     * Returns copies of the rows whose column {@code attrIndex} equals
     * {@code attrValue} (exact, case-sensitive match). The caller may modify
     * the returned rows freely.
     */
    public List<ArrayList<String>> getData4Value(String attrValue,int attrIndex){

        List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();
        for(ArrayList<String> row : data){
            if(sameValue(attrValue, row.get(attrIndex))){
                resultData.add(new ArrayList<String>(row));
            }
        }
        return resultData;
    }

    /**
     * Returns gain divided by split info for the given column, or 0.0 when the
     * split info is zero (the column has at most one distinct value, so it
     * cannot separate the rows).
     */
    public double getGainRatio(int attributeIndex){

        double splitInfo = getSplitInfo(attributeIndex);
        if(splitInfo == 0.0){
            return 0.0;
        }
        return getGain(attributeIndex) / splitInfo;
    }

    /** Returns the information gain of splitting on the given column. */
    public double getGain(int attributeIndex){
        return getEntropy() - getInfoAttribute(attributeIndex);
    }

    /**
     * Returns the split information of the given column, i.e. the entropy of
     * the distribution of its values; used as the divisor of the gain ratio.
     */
    public double getSplitInfo(int attributeIndex){

        return entropyOf(getAttributeValue(attributeIndex), data.size());
    }

    /** Counts how many rows carry each distinct target value. */
    public Map<String,Long> getTargetValue(){

        Map<String,Long> targetValueMap = new HashMap<String,Long>();
        for(List<String> row : data){
            increment(targetValueMap, row.get(row.size() - 1));
        }
        return targetValueMap;
    }

    /** Returns the last element of every row, in row order. */
    public static List<String> getTarget(List<ArrayList<String>> data){

        List<String> list = new ArrayList<String>(data.size());
        for(ArrayList<String> row : data){
            list.add(row.get(row.size() - 1));
        }
        return list;
    }

    /**
     * Returns the single distinct value of the list, or null when the list is
     * empty or contains more than one distinct value.
     */
    public static String IsPure(List<String> list){

        Set<String> distinct = new HashSet<String>(list);
        if(distinct.size() != 1){
            return null;
        }
        return distinct.iterator().next();
    }

    // Entropy (natural log) of a distribution given as counts that sum to total.
    private static double entropyOf(Map<String,Long> counts,long total){

        double entropy = 0.0;
        for(Long count : counts.values()){
            double p = count.doubleValue() / total;
            entropy -= p * Math.log(p);
        }
        return entropy;
    }

    // Adds one to the count stored under key, starting at 1 for a new key.
    private static void increment(Map<String,Long> counts,String key){

        Long current = counts.get(key);
        counts.put(key, current == null ? 1L : current.longValue() + 1L);
    }

    // Null-safe exact comparison of two cell values.
    private static boolean sameValue(String a,String b){

        return a == null ? b == null : a.equals(b);
    }

}
复制代码

 

测试类:将上面列出的属性和数据集分别读入两个List中。

复制代码
package C45Test;

import java.util.ArrayList;
import java.util.List;

import C45Test.DecisionTree.TreeNode;

/**
 * Entry point that builds a decision tree from a small built-in sample data
 * set with the attributes age, income, student and credit_rating.
 */
public class MainC45 {

    private static final List<ArrayList<String>> dataList = new ArrayList<ArrayList<String>>();
    private static final List<String> attributeList = new ArrayList<String>();

    // Attribute names, one per non-target column, in column order.
    private static final String[] SAMPLE_ATTRIBUTES = {
        "age", "income", "student", "credit_rating"
    };

    // Sample rows; the last column is the target value.
    private static final String[][] SAMPLE_ROWS = {
        {"youth", "high", "no", "fair", "no"},
        {"youth", "high", "no", "excellent", "no"},
        {"middle_aged", "high", "no", "fair", "yes"},
        {"senior", "low", "yes", "fair", "yes"},
        {"senior", "low", "yes", "excellent", "no"},
        {"middle_aged", "low", "yes", "excellent", "yes"},
        {"youth", "medium", "no", "fair", "no"},
        {"youth", "low", "yes", "fair", "yes"},
        {"senior", "medium", "yes", "fair", "yes"},
        {"youth", "medium", "yes", "excellent", "yes"},
        {"middle_aged", "high", "yes", "fair", "yes"},
        {"senior", "medium", "no", "excellent", "no"}
    };

    /** Builds the tree for the sample data, then prints an empty line. */
    public static void main(String args[]){

        DecisionTree dt = new DecisionTree();
        TreeNode node = dt.createDT(configData(),configAttribute());
        System.out.println();
    }

    /** Refills {@code dataList} with fresh copies of the sample rows and returns it. */
    private static List<ArrayList<String>> configData(){

        dataList.clear();
        for(String[] sample : SAMPLE_ROWS){
            ArrayList<String> row = new ArrayList<String>(sample.length);
            for(String cell : sample){
                row.add(cell);
            }
            dataList.add(row);
        }
        return dataList;
    }

    /** Refills {@code attributeList} with the sample attribute names and returns it. */
    private static List<String> configAttribute(){

        attributeList.clear();
        for(String name : SAMPLE_ATTRIBUTES){
            attributeList.add(name);
        }
        return attributeList;
    }
}
复制代码

 

大数运算工具类

复制代码
package C45Test;
import java.math.BigDecimal;

/**
 * Arithmetic helpers that route double operands through {@code BigDecimal}
 * using their decimal string form, so that e.g. 0.1 + 0.2 yields 0.3.
 * Results are converted back to {@code double}.
 */
public abstract class MathUtils {

    // Number of decimal places kept by div().
    private static final int DIV_SCALE = 10;

    /**
     * Adds two doubles using their decimal representations.
     * Precision is limited by the double operands and result.
     */
    public static double add(double value1,double value2){

        return toDecimal(value1).add(toDecimal(value2)).doubleValue();
    }

    /**
     * Adds two numbers given as decimal strings.
     *
     * @throws NumberFormatException if either string is not a valid decimal number
     */
    public static double add(String value1,String value2){

        BigDecimal big1 = new BigDecimal(value1);
        BigDecimal big2 = new BigDecimal(value2);
        return big1.add(big2).doubleValue();
    }

    /**
     * Divides {@code value1} by {@code value2}, rounding half-up to
     * {@code DIV_SCALE} decimal places.
     *
     * @throws ArithmeticException if {@code value2} is zero
     */
    public static double div(double value1,double value2){

        // RoundingMode replaces the deprecated BigDecimal.ROUND_HALF_UP int constant.
        return toDecimal(value1)
                .divide(toDecimal(value2), DIV_SCALE, java.math.RoundingMode.HALF_UP)
                .doubleValue();
    }

    /** Multiplies two doubles using their decimal representations. */
    public static double mul(double value1,double value2){

        return toDecimal(value1).multiply(toDecimal(value2)).doubleValue();
    }

    /** Subtracts {@code value2} from {@code value1} using their decimal representations. */
    public static double sub(double value1,double value2){

        return toDecimal(value1).subtract(toDecimal(value2)).doubleValue();
    }

    /** Returns the larger of the two values. */
    public static double returnMax(double value1, double value2) {

        return toDecimal(value1).max(toDecimal(value2)).doubleValue();
    }

    // Converts via the double's canonical decimal string rather than its exact binary value.
    private static BigDecimal toDecimal(double value){

        return new BigDecimal(String.valueOf(value));
    }
}
复制代码
  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
CAN(Controller Area Network,控制器局域网)总线协议是一种广泛应用于工业自动化、汽车电子等领域的串行通讯协议。其帧格式如下: <img src="https://img-blog.csdnimg.cn/20200925125252655.png" width="400"> CAN总线协议的帧分为标准帧和扩展帧两种,其中标准帧包含11位标识符,扩展帧包含29位标识符。在CAN总线上,所有节点都可以同时发送和接收数据,因此需要在帧中包含发送方和接收方的信息。 帧格式的具体解释如下: 1. 帧起始符(SOF):一个固定的位模式,表示帧的起始。 2. 报文控制(CTRL):包含几个控制位,如IDE、RTR等。其中IDE表示标识符的类型,0表示标准帧,1表示扩展帧;RTR表示远程请求帧,0表示数据帧,1表示远程请求帧。 3. 标识符(ID):11位或29位的标识符,用于区分不同的CAN消息。 4. 控制域(CTL):包含几个控制位,如DLC、EDL等。其中DLC表示数据长度,即数据域的字节数;EDL表示数据长度是否扩展,0表示标准数据帧,1表示扩展数据帧。 5. 数据域(DATA):0~8字节的数据。 6. CRC:用于校验数据是否正确。 7. 确认位(ACK):由接收方发送的确认信息,表示数据是否正确接收。 8. 结束符(EOF):一个固定的位模式,表示帧的结束。 以上就是CAN总线协议的帧格式。在实际应用中,节点之间通过CAN总线进行数据交换,通过解析帧中的各个字段,可以判断消息的发送方、接收方、数据内容等信息。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值