[Java][机器学习]决策树算法

大概介绍

在之前为了处理Iris花的分类算法,写了一个决策树算法,但这个算法局限性比较大,只能用于那一种情况,所以为了增强代码的复用性,在之前的基础上修改了算法,增强了复用性。略有遗憾的是,由于很多模块,比如快速排序,是为了处理Iris花专门写的,所以现在要进行一些转换才行,一定程度上增加了代码的复杂度和冗余度。这个是以后值得注意的地方,即在设计之初就应该注意到代码的复用问题,尽可能早的定义好通用接口,这样就能增强模块的复用度。 修改以后的程序使用起来方便了好多,在使用之前要定义一套变量ProjectInfo,里面包括决策树的深度等参量,用于调节决策树的生成。

举个例子吧,我想用我的程序来分离正态分布随机变量和均匀随机变量。这是我实现用matlab生成的变量,存在excel里面。这里写图片描述
其中黄点和绿点是均匀分布的随机变量,黄点是训练集,绿点是检验集。而蓝点红点是正太分布随机变量,其中红点是训练集,而蓝点是检验集。可以看到这些正态分布的点分成3坨,有3个中心,分别是(5,15),(15,5)以及(10,10)。其中中间那坨的方差要大一些。我想用我的算法来将这两类分开来,那么首先就要定义一个ProjectInfo对象,用来控制决策树的一些相关变量,比如深度等。ProjectInfo的参数有如下几个:

int AttrNum ; //属性值的数目,如有2种属性,则AttrNum=2 ,在这个例子中,原始属性为2种,分别为x,y但是看图可以明显得知用斜线可以更好地分类,我这里又加了两个额外属性,分别为x+y,x-y。此时AttrNum就是4.看官要是高兴,也可以加x+2y,x+3y这些额外属性。当然了,要记得更改这个值
int[] AttrType ; //Attributes[i]表示该值的属性,double为0,boolean为1。这个例子中x,y,x+y,x-y均为double,所以为{0,0,0,0}
String[] SqlColName; //SQL数据库中列名. 这个是为了方便sql读取的,存放sql数据库中的列名
String[] AttrDescription ; //对Arrtibute的描述,SqlColNmae中的名字要与数据库严格对应,而这里则不需要,自己能看懂就行了,属于注释类的
int dataNodeNum ; //数据集的数目。指的是训练集的数目。我这个例子中的训练集总共2000个点,其中1200个均匀分布,800个高斯分布。所以dataNodeNum=2000
String[] classNum ; //数据分属的各个类的名字,如果要从数据库导入,则必须跟数据库中的类名一样,如果不是,则随意
int RuleDeep ; //决策树的深度。如果节点的深度大于该值,则不再分裂
int NodeLimitNum ; //若节点内数据量小于等于该值,则不再分裂
double NodeLimitEntropy ; //若节点内数据熵小于等于该值,则不再分裂
double NodeDeltaEntropy ; //若节点分裂前后数据熵变化小于等于该值,则不再分裂

定义之后,需要训练集以及检验集。考虑到通用性,不再写通用的函数,只需要保证最后格式就行。具体格式为:
1.送入RuleNode函数的必须要是double[][] example格式
2.example.length=dataNodeNum
3.example[i].length=AttrNum+3 其中example[i][0]存放数据序号,从0-(dataNodeNum-1) ; example[i][1]=ClassType,其值从0-(classNum.length-1)。 在这个例子中,由于只有两类点,所以其值为0,1.如果是之前的Iris,其值范围就是0-2 ; example[i][2]=tempClassType. 值只能是0或者1. ;example[i][3-(2+AttrNum)]存放各个属性值。
当然为了方便,还是准备了从sql以及excel中读取数组的接口。当然读取过来的不一定能直接用,需要自己修改。你也可以自己写接口把数组准备好。总之只要格式对就行了。

准备好数组和ProjectInfo以后,生成一个RuleNode的节点root。RuleNode构造函数的接口为
public RuleNode(double[][] input,int deep,String tag,ProjectInfo pInfo)
其中input为训练集,deep为深度,初始深度为0,tag记录节点位置,初始值为“0”。所以例子中的调用方法为
RuleNode root=new RuleNode(trainset,0,”0”,pInfo) ;

之后可以用root.PrintRule(String filename,RuleNode root, ProjectInfo pInfo) 来打印决策树的规则,输出在桌面上。
rule.getNodeNum(root)则可以观察该决策树总共有多少个节点

为了检验决策树的错误率,可以使用Estimate类。
Estimate es=new Estimate(rule,trainset,pInfo) ;
System.out.println(“训练集的错误率为:”+es.ErrRatio);

在这个例子中,我设的参数值为:
int AttrNum=4 ;
int[] AttrType={0,0,0,0} ;
String[] SqlColName={“x”,”y”,”x+y”,”x-y”} ;
String[] AttrDescription={“x”,”y”,”x+y”,”x-y”} ;
int dataNodeNum=2000 ;
String[] classNum={“Average”,”Normal”} ;
int RuleDeep=8 ;
int NodeLimitNum=10 ;
double NodeLimitEntropy=0 ;
double NodeDeltaEntropy=0 ;
最后得到的结果是规则数共有131个节点,训练集的错误率为15.65%,检验集的错误率为19.2%。生成的规则数如下:

x+y<=14.244
| x+y<=12.206
| | y<=4.7925: Class0
| | y>4.7925
| | | y<=4.8215: Class1
| | | y>4.8215: Class0
| x+y>12.206
| | x<=5.204: Class0
| | x>5.204
| | | x<=7.156499999999999
| | | | x<=5.2625: Class1
| | | | x>5.2625
| | | | | x<=5.9145: Class0
| | | | | x>5.9145: Class0
| | | x>7.156499999999999
| | | | x+y<=13.5105: Class0
| | | | x+y>13.5105
| | | | | x+y<=13.524999999999999: Class1
| | | | | x+y>13.524999999999999
| | | | | | x+y<=13.92: Class0
| | | | | | x+y>13.92: Class0
x+y>14.244
| x+y<=23.326500000000003
| | x<=3.101
| | | y<=16.8975
| | | | x<=0.8965000000000001: Class0
| | | | x>0.8965000000000001
| | | | | y<=12.908000000000001: Class0
| | | | | y>12.908000000000001
| | | | | | x<=0.9964999999999999: Class1
| | | | | | x>0.9964999999999999
| | | | | | | x-y<=-10.7015
| | | | | | | | x+y<=16.7165: Class0
| | | | | | | | x+y>16.7165: Class0
| | | | | | | x-y>-10.7015: Class1
| | | y>16.8975: Class0
| | x>3.101
| | | y<=2.8085
| | | | x-y<=14.558
| | | | | x<=16.951
| | | | | | y<=1.1775: Class0
| | | | | | y>1.1775
| | | | | | | x-y<=13.814499999999999
| | | | | | | | x+y<=14.5165: Class1
| | | | | | | | x+y>14.5165: Class0
| | | | | | | x-y>13.814499999999999: Class1
| | | | | x>16.951: Class1
| | | | x-y>14.558
| | | | | y<=2.1395: Class0
| | | | | y>2.1395: Class0
| | | y>2.8085
| | | | x+y<=16.6265
| | | | | y<=12.838000000000001
| | | | | | x<=6.4335
| | | | | | | x-y<=-7.651499999999999: Class0
| | | | | | | x-y>-7.651499999999999: Class0
| | | | | | x>6.4335
| | | | | | | x<=9.96
| | | | | | | | x-y<=3.2985: Class1
| | | | | | | | x-y>3.2985: Class1
| | | | | | | x>9.96
| | | | | | | | x-y<=9.027999999999999: Class0
| | | | | | | | x-y>9.027999999999999: Class1
| | | | | y>12.838000000000001: Class1
| | | | x+y>16.6265
| | | | | x-y<=13.459
| | | | | | x+y<=22.872999999999998
| | | | | | | x-y<=-8.513
| | | | | | | | x-y<=-10.426: Class1
| | | | | | | | x-y>-10.426: Class1
| | | | | | | x-y>-8.513
| | | | | | | | x<=7.8805: Class1
| | | | | | | | x>7.8805: Class1
| | | | | | x+y>22.872999999999998
| | | | | | | x+y<=23.162499999999998
| | | | | | | | x<=17.2825: Class0
| | | | | | | | x>17.2825: Class1
| | | | | | | x+y>23.162499999999998
| | | | | | | | x<=5.6205: Class0
| | | | | | | | x>5.6205: Class1
| | | | | x-y>13.459
| | | | | | x+y<=21.8705
| | | | | | | x+y<=21.7155: Class0
| | | | | | | x+y>21.7155: Class1
| | | | | | x+y>21.8705: Class0
| x+y>23.326500000000003
| | x+y<=27.7235
| | | y<=17.5545
| | | | y<=10.692499999999999
| | | | | x+y<=24.683999999999997
| | | | | | x+y<=23.731499999999997: Class0
| | | | | | x+y>23.731499999999997
| | | | | | | x+y<=23.7855: Class1
| | | | | | | x+y>23.7855
| | | | | | | | x+y<=24.435000000000002: Class0
| | | | | | | | x+y>24.435000000000002: Class0
| | | | | x+y>24.683999999999997
| | | | | | x+y<=26.735500000000002: Class0
| | | | | | x+y>26.735500000000002
| | | | | | | x+y<=26.7965: Class1
| | | | | | | x+y>26.7965: Class0
| | | | y>10.692499999999999
| | | | | x+y<=25.2045
| | | | | | x-y<=-2.4989999999999997
| | | | | | | y<=14.745000000000001
| | | | | | | | x<=11.157499999999999: Class0
| | | | | | | | x>11.157499999999999: Class1
| | | | | | | y>14.745000000000001
| | | | | | | | y<=15.202: Class1
| | | | | | | | y>15.202: Class0
| | | | | | x-y>-2.4989999999999997
| | | | | | | x<=12.005500000000001
| | | | | | | | x<=10.674: Class0
| | | | | | | | x>10.674: Class1
| | | | | | | x>12.005500000000001
| | | | | | | | x<=12.3055: Class0
| | | | | | | | x>12.3055: Class1
| | | | | x+y>25.2045
| | | | | | x<=10.908000000000001: Class0
| | | | | | x>10.908000000000001
| | | | | | | x+y<=25.803: Class0
| | | | | | | x+y>25.803
| | | | | | | | x<=11.1095: Class1
| | | | | | | | x>11.1095: Class0
| | | y>17.5545: Class0
| | x+y>27.7235
| | | y<=12.2935
| | | | x<=17.162: Class1
| | | | x>17.162: Class0
| | | y>12.2935: Class0

附录:代码

ProjectInfo

package Classification;

public class ProjectInfo {

    public int AttrNum ;       //属性值的数目,如有2种属性,则AttrNum=2 
    public int[] AttrType ;  //Attributes[i]表示该值的属性,double为0,boolean为1
    public String[] SqlColName;  //SQL数据库中列名
    public String[] AttrDescription ; //对Arrtibute的描述
    public int dataNodeNum ;   //数据集的数目
    public String[] classNum ; //数据分属的各个类的名字
    //正常情况下,double[i][0]存储SetNum,double[i][1]存储Type,double[i][1]存储tempType
    //double[i][2-(2+AttrNum-1)]存储属性值
    //public int addAttrNum ;   //自定义属性数目 ,默认为0 若为>0则启动自定义属性函数
    public int RuleDeep ;
    public int NodeLimitNum ;  //若节点内数据数小于等于该值,则不再分裂
    public double NodeLimitEntropy ; //若节点内数据熵小于等于该值,则不再分裂
    public double NodeDeltaEntropy ; //若节点内数据熵小于等于该值,则不再分裂


    public ProjectInfo(){      //空对象
        this.AttrNum=0 ;
        this.AttrType=null ;
        this.SqlColName=null ;
        this.AttrDescription=null ;
        this.dataNodeNum=0 ;
        this.classNum=null ;
        //this.addAttrNum=0 ;
        this.RuleDeep=100 ;
        this.NodeLimitNum=1 ;
        this.NodeLimitEntropy=0 ;
        this.NodeDeltaEntropy=1 ;
    }

    public ProjectInfo(int AttrNum,int[] AttrType,String[] SqlColName,
            String[] AttrDescription,int dataNodeNum,String[] classNum,
            int RuleDeep,int NodeLimitNum,
            double NodeLimitEntropy,double NodeDeltaEntropy){  

        this.AttrNum=AttrNum ;
        if(AttrType.length!=this.AttrNum) System.out.println("ERROR:ProjectInfo_AttrType输入值数目错误");
        this.AttrType=AttrType ;
        if(SqlColName.length!=this.AttrNum) System.out.println("ERROR:ProjectInfo_SqlColName输入值数目错误");
        this.SqlColName=SqlColName ;
        if(AttrDescription.length!=this.AttrNum) System.out.println("ERROR:ProjectInfo_AttrDescription输入值数目错误");
        this.AttrDescription=AttrDescription ;
        this.dataNodeNum=dataNodeNum ;
        this.classNum=classNum ;
        //this.addAttrNum=addAttrNum ;
        this.RuleDeep=RuleDeep ;
        this.NodeLimitNum=NodeLimitNum ;
        this.NodeLimitEntropy=NodeLimitEntropy ;
        this.NodeDeltaEntropy=NodeDeltaEntropy ;
    }

    public static void main(String[] args){

        int AttrNum=4 ;
        int[] AttrType={0,0,0,0} ;
        String[] SqlColName={"x","y","x+y","x-y"} ;
        String[] AttrDescription={"x","y","x+y","x-y"} ;
        int dataNodeNum=2000 ;
        String[] classNum={"Average","Normal"} ;
        //int addAttrNum=0 ;
        int RuleDeep=8 ;
        int NodeLimitNum=10 ;
        double NodeLimitEntropy=0 ;
        double NodeDeltaEntropy=0 ;

        ProjectInfo pInfo=new ProjectInfo(AttrNum,AttrType,SqlColName,
                AttrDescription,dataNodeNum,classNum,RuleDeep,
                NodeLimitNum,NodeLimitEntropy,NodeDeltaEntropy) ;

        FileIO ep=new FileIO() ;
        int len=2000 ;
        int width=3 ; //excel中的列数 
        int[] nodea={0,0} ;
        int[] nodeb={width-1,len-1} ;
        double[][] get1=ep.getArray("origin",0,nodea, nodeb) ;
        double[][] trainset=new double[len][pInfo.AttrNum+3] ;
        for(int i=0;i<len;i++){
            trainset[i][0]=i+1 ;
            trainset[i][1]=get1[i][2] ;
            trainset[i][2]=get1[i][2] ;
            trainset[i][3]=get1[i][0] ;
            trainset[i][4]=get1[i][1] ;
            trainset[i][5]=get1[i][0]+get1[i][1] ;
            trainset[i][6]=get1[i][0]-get1[i][1] ;
        }
        len=1000 ;
        width=3 ;
        nodeb[0]=width-1 ;
        nodeb[1]=len-1 ;
        double[][] get2=ep.getArray("origin",1,nodea, nodeb) ;
        double[][] examset=new double[len][pInfo.AttrNum+3] ;
        for(int i=0;i<len;i++){
            examset[i][0]=i+1 ;
            examset[i][1]=get2[i][2] ;
            examset[i][2]=get2[i][2] ;
            examset[i][3]=get2[i][0] ;
            examset[i][4]=get2[i][1] ;
            examset[i][5]=get2[i][0]+get2[i][1] ;
            examset[i][6]=get2[i][0]-get2[i][1] ;
        }   

        RuleNode root=new RuleNode(trainset,0,"0",pInfo) ;
        root.PrintRule("rule", root, pInfo);
        RuleNode rule=root ;
        System.out.println("规则数的节点数为 "+rule.getNodeNum(root));

        Estimate es=new Estimate(rule,trainset,pInfo) ;
        System.out.println("训练集的错误率为:"+es.ErrRatio);
        es=new Estimate(rule,examset,pInfo) ;
        System.out.println("检验集的错误率为:"+es.ErrRatio);


        //es=new Estimate(rule,trainset,pInfo) ;
        //System.out.println("训练集的错误率为:"+es.ErrRatio);


    }


}

RuleNode

package Classification;

public class RuleNode {

    public int deep ;
    public double formerEntropy ;  
    public double[][] datalist ;
    public String tag ;

    public int nodeType=-1 ;
    //public int minSetNum=1 ;   //一个有效枝节点的最少数据数 若少于这个,则不再分裂

    public int divideType=-1 ;
    public double valveValue=-1 ;
    public double deltaEntropy=1 ;

    public RuleNode leftChild=null ;
    public RuleNode rightChild=null ;
    public double laterEntropy=2 ;

    public RuleNode(double[][] input,int deep,String tag,ProjectInfo pInfo){
        //System.out.println("正在建立第 "+tag+" 号节点");                                  //used for debug
        //System.out.println("第 "+tag+" 号节点数组数目:"+input.length);    

        this.deep=deep ;
        this.tag=tag ;
        this.datalist=input ;
        this.formerEntropy=getDataListEntropy(input) ; //undefined
        this.nodeType=-1 ;

        if ((this.deep>pInfo.RuleDeep)||(this.datalist.length<=pInfo.NodeLimitNum)||(this.formerEntropy<=pInfo.NodeLimitEntropy)){
            //深度过大或者点数过小,或者分来的数组足够纯净,则不再分类
            this.leftChild=this.rightChild=null ;
            int temp=decideType(this.datalist) ;    //undefined
            if ((temp==0)||(temp==1)) this.nodeType=temp ;
            else System.out.println("ERROR:函数decideType输出值不合法") ;
        }else{
            Hunt hunt=new Hunt(input,pInfo) ;             //undefined
            this.divideType=hunt.type ;
            this.valveValue=hunt.value_value ;
            this.laterEntropy=hunt.min_entropy ;
            this.deltaEntropy=this.formerEntropy-this.laterEntropy ;

            if (this.deltaEntropy<pInfo.NodeDeltaEntropy){
                this.leftChild=this.rightChild=null ; //if deltaEntropy<0.05 or deep>5 no longer continue
                int temp=decideType(this.datalist) ;
                if ((temp==0)||(temp==1)) this.nodeType=temp ;
                else System.out.println("ERROR:函数decideType输出值不合法") ;
            }else{
                //System.out.println("tag1") ;              //used for debug
                double[][] leftList=Divide(input,this.divideType,this.valveValue,0,pInfo) ;
                double[][] rightList=Divide(input,this.divideType,this.valveValue,1,pInfo) ;

                //if (tag=="001") System.out.println(leftChild==null) ;//used for debug
                //if (leftList==null) System.out.println(tag+"节点的左子树为空") ; else System.out.println(tag+"节点的左子树长"+leftList.length) ;     //used for debug
                //if (rightList==null) System.out.println(tag+"节点的右子树为空") ; else System.out.println(tag+"节点的右子树长"+rightList.length) ;   //used for debug

                //if ((leftList==null)||(rightList==null)) this.leftChild=this.rightChild=null ;
                if ((leftList.length==0)||(rightList.length==0)) {
                    this.leftChild=this.rightChild=null ;
                    int temp=decideType(this.datalist) ;
                    if ((temp==0)||(temp==1)) this.nodeType=temp ;
                    else System.out.println("ERROR:函数decideType输出值不合法") ;
                }
                else{
                    this.leftChild=new RuleNode(leftList,deep+1,tag+'0',pInfo) ;
                    this.rightChild=new RuleNode(rightList,deep+1,tag+'1',pInfo) ;
                }
            }
        }
    }

    public static double[][] Divide(double[][] input,int attribute,double valve,int methodtype,ProjectInfo pInfo){
        double[][] rs=null ;
        //通过attribute value type来将input分成两部分
        if (methodtype==0){ //此处为methodtype=1时的情况,也就是attr value<valve的情况
            int num=0 ;
            for(int i=0;i<input.length;i++){
                if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
                    if (input[i][attribute+3]<=valve) num++ ;
                }else System.out.println("ERROR:The value of attribute value illegal");
            }
            rs=new double[num][pInfo.AttrNum+3] ;
            int index=0 ;
            for(int i=0;i<input.length;i++){
                if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
                    if (input[i][attribute+3]<=valve) rs[index++]=input[i] ;
                }else System.out.println("ERROR:The value of attribute value illegal");
            }
            return rs ;
        }else if(methodtype==1){
            int num=0 ;
            for(int i=0;i<input.length;i++){
                if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
                    if (input[i][attribute+3]>valve) num++ ;
                }else System.out.println("ERROR:The value of attribute value illegal");
            }
            rs=new double[num][pInfo.AttrNum+3] ;
            int index=0 ;
            for(int i=0;i<input.length;i++){
                if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
                    if (input[i][attribute+3]>valve) rs[index++]=input[i] ;
                }else System.out.println("ERROR:The value of attribute value illegal");
            }
            return rs ;
        }else System.out.println("ERROR:RuleNode_Divide_methodtype value illegal");
        return rs ;
    }

    public int getNodeNum(RuleNode node){
        if (node.nodeType==-1){
            int num1=getNodeNum(node.leftChild) ;
            int num2=getNodeNum(node.rightChild) ;
            return num1+num2+1 ;
        }else{
            return 1 ;
        }
    }

    public void PrintRule(String filename,RuleNode node,ProjectInfo pInfo){
        String out="" ;
        for(int i=0;i<node.deep;i++) out+="| " ;
        if (node.leftChild.nodeType==-1) {
            FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+"<="+node.valveValue);
            PrintRule(filename,node.leftChild,pInfo) ;
        }else{
            FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+"<="+node.valveValue+": Class"+node.leftChild.nodeType);
        }
        if (node.rightChild.nodeType==-1) {
            FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+">"+node.valveValue);
            PrintRule(filename,node.rightChild,pInfo) ;
        }else{
            FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+">"+node.valveValue+": Class"+node.rightChild.nodeType);
        }
    }

    private static double getDataListEntropy(double[][] input){  //根据输入的二维数组确定datalist的熵
        DataProperty dp=new DataProperty() ;
        double rs_entropy=-1 ;
        //通过tempType的值来计算irisdata数组的熵
        //tempType只有3个值,0表示类1,1表示类2,-1表示其他类 一般用于表示异常
        int num1=0,num2=0 ;
        for(int i=0;i<input.length;i++){
            if(input[i][2]==0) num1++ ;
            if(input[i][2]==1) num2++ ;
        }
        rs_entropy=dp.getEntropy(num1, num2) ;
        return rs_entropy ;
    }

    private static int decideType(double[][] input){
        int rs=-1 ;
        int num0=0,num1=0 ;
        for(int i=0;i<input.length;i++){
            if (input[i][2]==0) num0++ ;
            if (input[i][2]==1) num1++ ;
        }
        if (num0<num1) rs=1 ; //有条件的话可以吧num0=num1时node的归属用随机数来实现
        else rs=0 ;  

        return rs ;
    }

}

Hunt

package Classification;

public class Hunt {
    public double min_entropy ;
    public double value_value ;
    public int type ;

    public Hunt(double[][] dataset,ProjectInfo pInfo){
        //1. calculate the entropy of initial dataset
        //2. find best attritube from 4 
        double[][] rs=new double[pInfo.AttrNum][2] ;

        int mintype=-1 ;
        double minentropy=1 ;
        double valve_value=-1 ;

        for(int i=0;i<pInfo.AttrNum;i++){
            if (pInfo.AttrType[i]==0) rs[i]=FindBestValve(preDeal(dataset,i,pInfo),pInfo) ;
            else if (pInfo.AttrType[i]==1) {rs[i]=BoolAttr(dataset,i) ;} ///////// undefined
            else System.out.println("ERROR:Hunt_Hunt_pInfo.AttrType["+i+"]设定有问题");
            //rs[i][0]=entropy rs[i][1]=valve
            if(rs[i][0]<minentropy){
                minentropy=rs[i][0] ;
                valve_value=rs[i][1] ;
                mintype=i ;
            }
        }
        //3. find the best one and output
        this.min_entropy=minentropy ;
        this.value_value=valve_value ;
        this.type=mintype ;
    }

    private static double[] BoolAttr(double[][] input,int Attr){
        int total_len=input.length ;
        int len1=0,len2=0 ;
        for(int i=0;i<total_len;i++){
            if (input[i][3+Attr]==0) len1++ ;
            if (input[i][3+Attr]==1) len2++ ;
        }
        if((len1+len2)!=total_len) System.out.println("ERROR:Hunt_BoolAttr_bool变量中有异值") ;
        DataProperty dp=new DataProperty() ;
        double[] rs=new double[2] ;
        rs[0]=dp.getEntropy(len1, len2) ;
        rs[1]=0.5 ;
        return rs ;
    }

    private static double[][] preDeal(double[][] dataset,int attr,ProjectInfo pInfo){  //transfer IrisData[] to int[][] to fit the followign processing
        if ((attr<pInfo.AttrNum)&&(attr>=0)){
            double[][] rs=new double[dataset.length][3] ; //3 attributes:Attribute Value,Number,,tempType
            for(int i=0;i<dataset.length;i++){
                rs[i][0]=dataset[i][attr+3] ;
                rs[i][1]=dataset[i][0] ;
                rs[i][2]=dataset[i][2] ;         //ATTENTION the taken value is tempTyoe!
            }
            return rs ;
        }else {System.out.println("ERROR:Hunt_preDeal_type输入值不正确");return null ;}
    }
    private static double[] FindBestValve(double[][] input,ProjectInfo pInfo){  
        //要考虑Type的多值性,最好只有两个值
        //print(input) ;                            // used for debug
        double[][] sorted=QuickSort(input,0,input.length-1) ; //1st step:sort the input array
        //接下来应该要在不同值区间内循环,挑一个熵值最小的。
        double min_entropy=2 ;
        double valve_value=-1 ;
        //System.out.println(sorted==null);         // used for debug
        for(int i=0;i<sorted.length-1;i++){
            // calculate the entropy of the division whose valve is between i and i+1
            if (sorted[i][0]!=sorted[i+1][0]){      //避免在两个相同值之间分析的情况
                double temp_entropy=CalculateEntropy(sorted,i) ; 
                if (temp_entropy<min_entropy){
                    min_entropy=temp_entropy ;
                    valve_value=(sorted[i][0]+sorted[i+1][0])/2 ;
                }
            }
        }
        double[] rs=new double[2] ;
        rs[0]=min_entropy ;
        rs[1]=valve_value ;
        return rs ;
    }

    private static double CalculateEntropy(double[][] sorted,int i) {  //can only deal with the data which have only two classes
        DataProperty dp=new DataProperty() ; //initialization of dataproperty

        double rs_entropy=-1 ;
        //double tagclass=sorted[0][2] ;

        int num1=0 ;
        int num2=0 ;
        for(int x=0;x<i+1;x++){
            if(sorted[x][2]==0) num1++ ;
            else if(sorted[x][2]==1) num2++ ;
            else System.out.println("ERROR from CalculateEntropy: the value of tempType of a item is -1");
        }
        double entropy1=dp.getEntropy(num1,num2) ;
        int tnum1=num1+num2 ; //total number of the former sequence

        num1=0 ;
        num2=0 ;
        for(int x=i+1;x<sorted.length;x++){
            if(sorted[x][2]==0) num1++ ;
            else if(sorted[x][2]==1) num2++ ;
            else System.out.println("ERROR from CalculateEntropy: the value of tempType of a item is -1");  
        }
        double entropy2=dp.getEntropy(num1,num2) ;
        int tnum2=num1+num2 ;
        rs_entropy=(entropy1*tnum1+entropy2*tnum2)/(tnum1+tnum2) ;
        return rs_entropy ;
    } 

    private static double[][] QuickSort(double[][] input,int low,int high){
        if(low>=high) return null ;
        int first=low ;
        int last=high ;
        double[] key=input[low] ;
        while(first<last){
            while((first<last)&&(input[last][0]>=key[0])) --last ;
            input[first]=input[last] ;
            while((first<last)&&(input[first][0]<=key[0])) ++first ;
            input[last]=input[first] ;
        }
        input[first]=key ;

        double[][] res1,res2 ;

        if (first-1>low) {res1=QuickSort(input,low,first-1) ;}
        else if(first-1==low) {double[][] temp={input[low]} ;res1=temp ;}
        else{res1=null ;}

        if(high>first+1){res2=QuickSort(input,first+1,high) ;}
        else if(high==first+1){double[][] temp={input[high]} ;res2=temp ;}
        else{res2=null ;}

        double[][] finalres ;
        finalres=Combine(res1,res2,key) ;
        return finalres ;
    }
    private static double[][] Combine(double[][] res1,double[][] res2,double[] key){
        int len1,len2 ;
        if(res1==null) len1=0 ;
        else len1=res1.length ;
        if(res2==null) len2=0 ;
        else len2=res2.length ;

        double[][] res=new double[len1+len2+1][3] ;
        int index=0 ;
        for(int i=0;i<len1;i++) res[index++]=res1[i] ;
        res[index++]=key ;
        for(int i=0;i<len2;i++) res[index++]=res2[i] ;
        return res ;
    }
}

Estimate

package Classification;

import java.util.*; ;
public class Estimate {

    List<double[]> list0 ;
    List<double[]> list1 ;
    double[][] array0 ;
    double[][] array1 ;
    RuleNode examtree ;
    double ErrRatio ;

    public Estimate(RuleNode rule,double[][] examset,ProjectInfo pInfo){
        this.list0=new ArrayList<double[]>() ;
        this.list1=new ArrayList<double[]>() ;
        this.examtree=examTree(rule,examset,pInfo) ;
        this.ErrRatio=getErrRatio(this.list0,this.list1) ;
        this.array0=convert(list0,pInfo) ;
        this.array1=convert(list1,pInfo) ;
    }

    private double getErrRatio(List list0,List list1){
        double len1=list0.size() ;
        double len2=list1.size() ;
        double errnum1=0,errnum2=0 ;
        for(int i=0;i<len1;i++){
            double[] temp=(double[])list0.get(i) ; 
            if (temp[2]==1) errnum1++ ;
        }
        for(int i=0;i<len2;i++){
            double[] temp=(double[])list1.get(i) ;
            if(temp[2]==0) errnum2++ ;
        }
        double erratio=(errnum1+errnum2)/(len1+len2) ;
        return erratio  ;       
    }
    private double[][] convert(List list,ProjectInfo pInfo){
        int len=list.size() ;
        double[][] rs=new double[len][3+pInfo.AttrNum] ;
        for(int i=0;i<len;i++){
            rs[i]=(double[])list.get(i) ;
        }
        return rs ;
    }

    private RuleNode examTree(RuleNode node,double[][] data,ProjectInfo pInfo){
        node.datalist=data ;
        node.formerEntropy=getDataListEntropy(data) ;
        if (node.nodeType==-1) { //this node is not a leaf node
            double[][] left=RuleNode.Divide(data, node.divideType, node.valveValue, 0,pInfo) ;
            double[][] right=RuleNode.Divide(data, node.divideType, node.valveValue, 1,pInfo) ;
            if (left.length==0) node.leftChild=null ;
            else node.leftChild=examTree(node.leftChild,left,pInfo) ;
            if(right.length==0) node.rightChild=null ;
            else node.rightChild=examTree(node.rightChild,right,pInfo) ;
            return node ;
        }else{    // this node is a leaf node
            node.leftChild=null ;
            node.rightChild=null ;

            int len=node.datalist.length ;   //将判定为0或者1的类调入到list0,list1中
            int num=0 ;
            if (node.nodeType==0) {
                for(int i=0;i<len;i++) list0.add(node.datalist[i]) ;
            }
            if (node.nodeType==1) {
                for(int i=0;i<len;i++) list1.add(node.datalist[i]) ;
            }
            return node ;
        }
    }


    private static double getDataListEntropy(double[][] input){  //根据输入的二维数组确定datalist的熵
        DataProperty dp=new DataProperty() ;
        double rs_entropy=-1 ;
        //通过tempType的值来计算irisdata数组的熵
        //tempType只有3个值,0表示类1,1表示类2,-1表示其他类 一般用于表示异常
        int num1=0,num2=0 ;
        for(int i=0;i<input.length;i++){
            if(input[i][2]==0) num1++ ;
            if(input[i][2]==1) num2++ ;
        }
        rs_entropy=dp.getEntropy(num1, num2) ;
        return rs_entropy ;
    }

}

FileOP

package Classification;

//import java.io.File;
import java.io.* ;

import jxl.Workbook;
import jxl.write.Label;
import jxl.write.WritableSheet;
import jxl.write.WritableWorkbook;
import jxl.Cell ;
import jxl.Sheet ;

public class FileIO {
    public FileIO(){

    }

    public void PrintDoubleArray(double[][] input,String filename,ProjectInfo pInfo){
        try{
            String rootname="C:\\Users\\multiangle\\Desktop\\" ;  
            String path=rootname+filename+".xls" ;
            File file=new File(path) ;
            WritableSheet sheet ;
            WritableWorkbook book ;
            if (file.exists()) {
                Workbook wb=Workbook.getWorkbook(file) ;
                book=Workbook.createWorkbook(file, wb) ;
                int sheetnum=book.getNumberOfSheets() ;
                sheet=book.createSheet("第"+sheetnum+"页", sheetnum) ;
                System.out.println("正在第"+sheetnum+"页打印double数组");
            }else {
                book=Workbook.createWorkbook(new File(path)) ;
                sheet=book.createSheet("第0页", 0) ;
                System.out.println("正在第0页打印double数组");
            }
            //System.out.println("已获取到需要的表单");

            String[] name=new String[3+pInfo.AttrNum] ;
            name[0]="SetNum" ;
            name[1]="Type" ;
            name[2]="tempType" ;
            for(int i=0;i<pInfo.AttrNum;i++){
                name[3+i]=pInfo.AttrDescription[i] ;
            }
            for(int i=0;i<3+pInfo.AttrNum;i++){
                Label temp=new Label(i,0,name[i]) ;
                sheet.addCell(temp);
            }

            int len=input.length ;
            int row=1 ;
            for(int i=0;i<len;i++){
                for(int j=0;j<3+pInfo.AttrNum;j++){
                    Label temp=new Label(j,row+i,String.valueOf(input[i][j])) ;
                    sheet.addCell(temp);
                }
            }
            book.write() ;
            book.close(); 
        }catch(Exception e){
            System.out.println(e) ;
            System.out.println("ERROR:ExcelPrint") ;
        }
    }

    public void PrintDoubleArray(double[][] input,String filename,String description,ProjectInfo pInfo){
        try{
            String rootname="C:\\Users\\multiangle\\Desktop\\" ;  
            String path=rootname+filename+".xls" ;
            File file=new File(path) ;
            WritableSheet sheet ;
            WritableWorkbook book ;
            if (file.exists()) {
                Workbook wb=Workbook.getWorkbook(file) ;
                book=Workbook.createWorkbook(file, wb) ;
                int sheetnum=book.getNumberOfSheets() ;
                sheet=book.createSheet("第"+sheetnum+"页", sheetnum) ;
                System.out.println("正在第"+sheetnum+"页打印double数组");
            }else {
                book=Workbook.createWorkbook(new File(path)) ;
                sheet=book.createSheet("第0页", 0) ;
                System.out.println("正在第0页打印double数组");
            }
            //System.out.println("已获取到需要的表单");

            Label descrip=new Label(0,0,description) ;
            sheet.addCell(descrip);
            String[] name=new String[3+pInfo.AttrNum] ;
            name[0]="SetNum" ;
            name[1]="Type" ;
            name[2]="tempType" ;
            for(int i=0;i<pInfo.AttrNum;i++){
                name[3+i]=pInfo.AttrDescription[i] ;
            }
            for(int i=0;i<3+pInfo.AttrNum;i++){
                Label temp=new Label(i,1,name[i]) ;
                sheet.addCell(temp);
            }

            int len=input.length ;
            int row=2 ;
            for(int i=0;i<len;i++){
                for(int j=0;j<3+pInfo.AttrNum;j++){
                    Label temp=new Label(j,row+i,String.valueOf(input[i][j])) ;
                    sheet.addCell(temp);
                }
            }
            book.write() ;
            book.close(); 
        }catch(Exception e){
            System.out.println(e) ;
            System.out.println("ERROR:ExcelPrint") ;
        }
    }

    public double[][] getArray(String filename,int sheetnum,int[] nodea,int[] nodeb){
        int left=nodea[0] ;
        int top=nodea[1] ;
        int right=nodeb[0] ;
        int bottom=nodeb[1] ;
        int len=bottom-top+1 ;
        int width=right-left+1 ;
        double[][] rs=new double[len][width] ;

        String root="C:\\Users\\multiangle\\Desktop\\" ; //基本目录为桌面
        String path=root+filename+".xls" ;
        File file=new File(path) ;
        if (!file.exists()) {System.out.println("ERROR:ExcelIO_getArray_File not exists");return null ;}
        else{
            try{
                Workbook book=Workbook.getWorkbook(file) ;
                Sheet sheet=book.getSheet(sheetnum) ;
                for(int i=0;i<len;i++){
                    for(int j=0;j<width;j++){
                        Cell cell=sheet.getCell(j,i) ;
                        double temp=Double.parseDouble(cell.getContents()) ;
                        rs[i][j]=temp ;
                    }
                }
                return rs ;
            }catch(Exception e){
                System.out.println("ERROR:ExcelIO_getArray");
                System.out.println(e);
                return null ;
            }
        }
    }

    public static void PrintTxtln(String filename,String line){
        try{
            String root="C:\\Users\\multiangle\\Desktop\\" ; //基本目录为桌面
            String path=root+filename+".txt" ;
            FileWriter out=new FileWriter(path,true) ;
            out.write(line+"\r\n");
            out.close(); 
        }catch(Exception e){
            System.out.println(e);
        }

    }


}

dataSqlGet


package Classification;

import java.sql.* ;
import Classification.ProjectInfo ;

public class dataSqlGet {

    public double[][] dataset ;

    public dataSqlGet(String DatabaseName,ProjectInfo pInfo){
        ResultSet rs=getResultSet(DatabaseName) ;
        this.dataset=ResultDeal(rs,pInfo) ;
    }

    private static ResultSet getResultSet(String DatabaseName){
        String JDriver="com.microsoft.sqlserver.jdbc.SQLServerDriver";//SQL数据库引擎
        String connectDB="jdbc:sqlserver://127.0.0.1:1433;DatabaseName=multiangle";//数据源

        try{
            Class.forName(JDriver);//加载数据库引擎,返回给定字符串名的类
            System.out.println("数据库驱动成功");
        }catch(ClassNotFoundException e){ //e.printStackTrace();
            System.out.println("加载数据库引擎失败");
            System.out.println(e);
        }     

        ResultSet rs ;
        try{
            String user="sa" ;
            String password="admin" ;
            Connection con=DriverManager.getConnection(connectDB,user,password);
            System.out.println("数据库连接成功");
            Statement stmt=con.createStatement() ;
            //String query="select ROW_NUMBER()over(order by class)as row,* from dbo.[bezdekIris.data]" ;
            String query="select ROW_NUMBER()over(order by class)as row,* from "+DatabaseName ;
            rs=stmt.executeQuery(query) ;
            return rs ;
        }catch(SQLException e){
            System.out.println(e) ;
            System.out.println("数据库内容读取失败");
            return null ;
        }
    }

    public static double[][] ResultDeal(ResultSet rs,ProjectInfo pInfo){
        System.out.println("SqlColName.length"+pInfo.dataNodeNum) ;
        try {
            int len=pInfo.dataNodeNum ;
            System.out.println("len: "+len);
            double[][] dataset=new double[pInfo.dataNodeNum][pInfo.AttrNum+3] ;
            int num=0 ;
            while((num<len)&&(rs.next())){
                dataset[num][0]=Integer.parseInt(rs.getString("row")) ; 

                String name=rs.getString("Class") ; //这个在录入数据的时候要特别注意
                int type=-1 ;
                int namelen=pInfo.classNum.length ;
                for(int i=0;i<namelen;i++){
                    if(name.equals(pInfo.classNum[i])) type=i ;
                }
                dataset[num][1]=type ;   //Type
                dataset[num][2]=-1 ;     //tempType
                //0-> SetNum ;1->Type;2->tempType;3-(3+attr)->attrvalue
                for(int i=0;i<pInfo.SqlColName.length;i++){
                    dataset[num][i+3]=Double.parseDouble(rs.getString(pInfo.SqlColName[i])) ;
                }

                num++ ;
                //System.out.println(setnum+"       "+SL+"      "+SW+"      "+PL+"      "+PW+"      "+type) ;
            }
            System.out.println("ResultSet 解析完毕");
            return dataset ;
        } catch (SQLException e) {
            System.out.println("ResultSet 解析出错");
            System.out.println(e);
            return null ;
        }
    }
}

DataProperty

package Classification;


public class DataProperty {

    public double getGini(int[] data){
        int len=data.length ;
        double sum=0 ;
        for(int i=0;i<len;i++)  sum+=data[i] ;
        double pre_gini=0 ;
        for(int i=0;i<len;i++)  pre_gini+= (data[i]/sum)*(data[i]/sum) ;
        double gini=1-pre_gini ;
        return gini ;
    }
    public double getGini(int a,int b){
        double c=a+b ;
        double gini=1-(a/c)*(a/c)-(b/c)*(b/c) ;
        return gini ;
    }

    public double getEntropy(int[] data){
        int len=data.length ;
        double sum=0 ;
        for(int i=0;i<len;i++)  sum+=data[i] ;  //get the summary of all data
        double pre_entro=0 ;
        for(int i=0;i<len;i++) {
            if (data[i]!=0){
                pre_entro+=(data[i]/sum)*Math.log(data[i]/sum)/Math.log(2) ;
            }
        }
        double entro=-pre_entro ;
        return entro ;
    }
    public double getEntropy(int ina,int inb){
        double a=(double)ina ;
        double b=(double)inb ;
        double entro ;
        if((a*b)!=0){
            double c=a+b ;
            double a1=(a/c)*mathLog2(a/c) ;
            double b1=(b/c)*mathLog2(b/c) ;
            entro=-a1-b1 ;
            return entro ;
        }else{  
            entro=0 ;
            return entro ;
        }
    }

    //inner methods----------------------------------------------------
    private static double mathLog(double data,double bottom){
        return Math.log(data)/Math.log(bottom) ;
    }
    private static double mathLog2(double data){
        return Math.log(data)/Math.log(2) ;
    }
}

生成数据集的matlab程序

clear 
clc
a(1:3,1:2000)=0 ;
a(1,1:1200)=20*rand(1200,1) ;
a(2,1:1200)=20*rand(1200,1) ;
a(3,1:1200)=0 ;

a(1,1201:1400)=5+1.5.*randn(200,1) ;
a(2,1201:1400)=15+1.5.*randn(200,1) ;
a(1,1401:1600)=15+1.5.*randn(200,1) ;
a(2,1401:1600)=5+1.5.*randn(200,1) ;
a(1,1601:2000)=10+2.*randn(400,1) ;
a(2,1601:2000)=10+2.*randn(400,1) ;

a(3,1201:2000)=1 ;



plot(a(1,1:1200),a(2,1:1200),'y.') 
hold on
plot(a(1,1201:2000),a(2,1201:2000),'r.') 

c=a' ;
xlswrite('C:\Users\multiangle\Desktop\origin.xlsx',c,1)

b(1:3,1:1000)=0 ;
b(1,1:600)=20*rand(600,1) ;
b(2,1:600)=20*rand(600,1) ;
b(3,1:600)=0 ;

b(1,601:700)=5+1.5.*randn(100,1) ;
b(2,601:700)=15+1.5.*randn(100,1) ;
b(1,701:800)=15+1.5.*randn(100,1) ;
b(2,701:800)=5+1.5.*randn(100,1) ;
b(1,801:1000)=10+2.*randn(200,1) ;
b(2,801:1000)=10+2.*randn(200,1) ;

b(3,601:1000)=1 ;

hold on
plot(b(1,1:600),b(2,1:600),'g.')
hold on
plot(b(1,601:1000),b(2,601:1000),'b.')
d=b'
xlswrite('C:\Users\multiangle\Desktop\origin.xlsx',d,2)

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
决策树算法是一种常用的机器学习算法,它可以用于分类和回归问题。在Java中,你可以使用Weka库来实现决策树算法。 Weka是一个开源的机器学习库,提供了丰富的机器学习算法和工具。下面是使用Weka库实现决策树算法的步骤: 1. 导入Weka库:首先,你需要在Java项目中导入Weka库。你可以在Weka官方网站上下载Weka的JAR文件,并将其添加到你的项目中。 2. 加载数据:使用Weka库,你可以从文件或其他数据源加载数据集。Weka支持多种数据格式,如ARFF、CSV等。你可以使用`Instances`类来表示数据集。 3. 构建决策树模型:使用`J48`类来构建决策树模型。`J48`是Weka中实现的C4.5算法,它是一种常用的决策树算法。你可以设置一些参数来调整模型的行为,如设置最小叶子数、剪枝等。 4. 训练模型:使用加载的数据集来训练决策树模型。你可以使用`buildClassifier`方法来进行训练。 5. 进行预测:训练完成后,你可以使用训练好的模型来进行预测。你可以使用`classifyInstance`方法来对新的实例进行分类预测。 下面是一个简单的示例代码,展示了如何使用Weka库实现决策树算法: ```java import weka.core.Instances; import weka.classifiers.trees.J48; import weka.core.converters.ConverterUtils.DataSource; public class DecisionTreeExample { public static void main(String[] args) throws Exception { // 加载数据集 DataSource source = new DataSource("path/to/your/dataset.arff"); Instances data = source.getDataSet(); // 设置类别属性 data.setClassIndex(data.numAttributes() - 1); // 构建决策树模型 J48 tree = new J48(); tree.buildClassifier(data); // 进行预测 Instance newInstance = data.instance(0); // 假设要预测第一个实例 double predictedClass = tree.classifyInstance(newInstance); System.out.println("预测结果:" + predictedClass); } } ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值