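Overview
Earlier I wrote a decision tree algorithm to classify the Iris data set, but it was quite limited and only worked for that one case. To make the code more reusable, I reworked the algorithm on top of the earlier version. One regret is that many modules, quicksort for example, were written specifically for the Iris data, so they now need some conversion before they can be used, which adds a certain amount of complexity and redundancy to the code. This is a lesson worth remembering: think about reuse at design time and define general-purpose interfaces as early as possible, so that modules are easier to reuse. The revised program is much more convenient to use. Before using it you define a ProjectInfo object, which holds parameters such as the depth of the decision tree and controls how the tree is grown.
As an example, I want to use the program to separate normally distributed random points from uniformly distributed ones. I generated the points in advance with MATLAB and stored them in Excel.
The yellow and green points are uniformly distributed: yellow is the training set and green is the test set. The red and blue points are normally distributed: red is the training set and blue is the test set. The normally distributed points form 3 clusters with 3 centers, at (5,15), (15,5) and (10,10), and the middle cluster has a somewhat larger variance. I want my algorithm to separate the two classes, so the first step is to define a ProjectInfo object that controls the relevant settings of the decision tree, such as its depth. ProjectInfo has the following parameters: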
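int AttrNum ; //Number of attributes; with 2 attributes, AttrNum=2. In this example the raw attributes are x and y, but the plot clearly shows that diagonal lines separate the classes better, so I added two extra attributes, x+y and x-y, which makes AttrNum 4. If you like, you can add further attributes such as x+2y or x+3y; just remember to update this value
int[] AttrType ; //AttrType[i] is the type of attribute i: 0 for double, 1 for boolean. In this example x, y, x+y and x-y are all double, so the value is {0,0,0,0}
String[] SqlColName; //Column names in the SQL database, stored here so the SQL reader can fetch the right columns
String[] AttrDescription ; //Description of each attribute. The names in SqlColName must match the database exactly; these need not, they only have to make sense to you and act as comments
int dataNodeNum ; //Size of the data set, meaning the training set. The training set in this example has 2000 points, 1200 uniform and 800 Gaussian, so dataNodeNum=2000
String[] classNum ; //Names of the classes the data belongs to. If the data is imported from a database they must match the class names in the database; otherwise any names will do
int RuleDeep ; //Depth limit of the decision tree. A node whose depth is greater than this value is not split further
int NodeLimitNum ; //A node holding this many records or fewer is not split further
double NodeLimitEntropy ; //A node whose entropy is less than or equal to this value is not split further
double NodeDeltaEntropy ; //A node whose best split lowers the entropy by less than this value is not split further (the code tests deltaEntropy<NodeDeltaEntropy)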
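The four stopping parameters work together, which is easier to see in isolation. The following is a minimal sketch, not part of the Classification package: the class StopRuleSketch and its method names are made up for illustration, and the comparisons mirror the ones in the RuleNode constructor listed in the appendix (two-class data only).

```java
// Hypothetical helper for illustration; not part of the original code.
final class StopRuleSketch {
    /** Binary entropy in bits of a node holding n0 records of class 0 and n1 of class 1. */
    static double entropy(int n0, int n1) {
        if (n0 == 0 || n1 == 0) return 0.0; // a pure (or empty) node has zero entropy
        double total = n0 + n1;
        double p0 = n0 / total;
        double p1 = n1 / total;
        return -(p0 * log2(p0) + p1 * log2(p1));
    }

    static double log2(double v) {
        return Math.log(v) / Math.log(2);
    }

    /** True when the node becomes a leaf before any split is searched. */
    static boolean stopBeforeSplit(int depth, int n0, int n1,
                                   int ruleDeep, int nodeLimitNum, double nodeLimitEntropy) {
        return depth > ruleDeep                          // RuleDeep
            || (n0 + n1) <= nodeLimitNum                 // NodeLimitNum
            || entropy(n0, n1) <= nodeLimitEntropy;      // NodeLimitEntropy
    }

    /** True when the best split found does not lower the entropy enough to be kept. */
    static boolean rejectSplit(double entropyBefore, double entropyAfter, double nodeDeltaEntropy) {
        return (entropyBefore - entropyAfter) < nodeDeltaEntropy; // NodeDeltaEntropy
    }
}
```

With NodeLimitEntropy=0 and NodeDeltaEntropy=0, as in the example further down, only the depth and size limits and completely pure nodes stop the growth, so the tree is allowed to fit the training set closely.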
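Once ProjectInfo is defined, you need a training set and a test set. For the sake of generality there is no longer a single fixed loading function; you only need to make sure the data ends up in the right format. The format is:
1. The array passed to the RuleNode constructor must be a double[][] example
2. example.length=dataNodeNum
3. example[i].length=AttrNum+3, where example[i][0] holds the record index, from 0 to dataNodeNum-1; example[i][1]=ClassType, with values from 0 to classNum.length-1 (this example has only two classes of points, so the values are 0 and 1; for the earlier Iris data the range would be 0-2); example[i][2]=tempClassType, which can only be 0 or 1; example[i][3] to example[i][2+AttrNum] hold the attribute values.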
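The row layout above can be sketched as follows for the x, y, x+y, x-y example. This is only an illustration: the class ExampleRowSketch and the method toExamples are hypothetical names, and the input is assumed to be rows of {x, y, label} with label 0 or 1.

```java
// Hypothetical helper for illustration; not part of the original code.
final class ExampleRowSketch {
    /**
     * Builds rows in the layout {index, ClassType, tempClassType, x, y, x+y, x-y}.
     * raw[i] is assumed to hold {x, y, label} with label 0 or 1.
     */
    static double[][] toExamples(double[][] raw) {
        int attrNum = 4; // x, y, x+y, x-y
        double[][] example = new double[raw.length][attrNum + 3];
        for (int i = 0; i < raw.length; i++) {
            double x = raw[i][0];
            double y = raw[i][1];
            double label = raw[i][2];
            example[i][0] = i;       // record index, 0 to dataNodeNum-1
            example[i][1] = label;   // ClassType
            example[i][2] = label;   // tempClassType, equal to ClassType in a two-class problem
            example[i][3] = x;
            example[i][4] = y;
            example[i][5] = x + y;   // derived attribute for diagonal splits
            example[i][6] = x - y;   // derived attribute for diagonal splits
        }
        return example;
    }
}
```

Adding another derived attribute such as x+2y would mean one more column here and AttrNum=5 in ProjectInfo.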
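For convenience there are still interfaces for reading arrays from SQL and from Excel. What they return cannot always be used directly, and you may need to adjust it yourself. You can also write your own loader to prepare the array. In short, all that matters is that the format is right.
With the array and the ProjectInfo ready, create a RuleNode called root. The signature of the RuleNode constructor is
public RuleNode(double[][] input,int deep,String tag,ProjectInfo pInfo)
where input is the training set, deep is the depth of the node (0 for the root), and tag records the position of the node (initially "0"). The call in this example is therefore
RuleNode root=new RuleNode(trainset,0,"0",pInfo) ;
After that, root.PrintRule(String filename,RuleNode root, ProjectInfo pInfo) prints the rules of the decision tree to a text file on the desktop.
rule.getNodeNum(root) returns the total number of nodes in the tree (rule here is just another reference to root).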
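To make the printed rule format and the node count concrete, here is a small self-contained sketch. TinyNode is a hypothetical stand-in for RuleNode that keeps only what is needed for printing; it writes to a StringBuilder rather than a file, and it follows the same conventions as the appendix code (nodeType -1 for an internal node, "<=" for the left branch, ">" for the right one, one "| " per level of depth).

```java
// Hypothetical stand-in for RuleNode, for illustration only.
final class TinyNode {
    int nodeType = -1;     // -1 = internal node, 0 or 1 = leaf class
    String attr;           // attribute description of an internal node
    double threshold;      // split threshold of an internal node
    TinyNode left;         // records with attr <= threshold
    TinyNode right;        // records with attr > threshold

    static TinyNode leaf(int cls) {
        TinyNode n = new TinyNode();
        n.nodeType = cls;
        return n;
    }

    static TinyNode split(String attr, double threshold, TinyNode left, TinyNode right) {
        TinyNode n = new TinyNode();
        n.attr = attr;
        n.threshold = threshold;
        n.left = left;
        n.right = right;
        return n;
    }

    /** Counts internal nodes and leaves together. */
    static int count(TinyNode n) {
        return n.nodeType == -1 ? 1 + count(n.left) + count(n.right) : 1;
    }

    /** Appends the rules of the subtree rooted at n, one line per branch. */
    static void render(TinyNode n, int depth, StringBuilder out) {
        if (n.nodeType != -1) return; // a lone leaf has no rule lines of its own
        renderBranch(n, n.left, "<=", depth, out);
        renderBranch(n, n.right, ">", depth, out);
    }

    private static void renderBranch(TinyNode parent, TinyNode child, String op,
                                     int depth, StringBuilder out) {
        for (int i = 0; i < depth; i++) out.append("| ");
        out.append(parent.attr).append(op).append(parent.threshold);
        if (child.nodeType == -1) {
            out.append('\n');
            render(child, depth + 1, out);
        } else {
            out.append(": Class").append(child.nodeType).append('\n');
        }
    }
}
```

For a tree built as split("x+y", 14.0, leaf(0), split("x", 5.0, leaf(0), leaf(1))), count returns 5 and render produces four lines, the last two indented by one "| ".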
To measure the error rate of the decision tree, use the Estimate class.
Estimate es=new Estimate(rule,trainset,pInfo) ;
System.out.println("Training set error rate: "+es.ErrRatio);
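Estimate pushes the whole data set down the tree at once. The same idea for a single record can be sketched as below. PredictSketch and its nested Node class are hypothetical; the field names nodeType, divideType and valveValue are borrowed from RuleNode, and rows are assumed to use the layout described earlier (label in column 2, attributes from column 3).

```java
// Hypothetical sketch for illustration; not part of the original code.
final class PredictSketch {
    static final class Node {
        int nodeType = -1;   // -1 = internal node, 0 or 1 = leaf class
        int divideType;      // index of the attribute tested at this node
        double valveValue;   // threshold: <= goes left, > goes right
        Node left;
        Node right;
    }

    static Node leaf(int cls) {
        Node n = new Node();
        n.nodeType = cls;
        return n;
    }

    static Node split(int divideType, double valveValue, Node left, Node right) {
        Node n = new Node();
        n.divideType = divideType;
        n.valveValue = valveValue;
        n.left = left;
        n.right = right;
        return n;
    }

    /** Walks from the root to a leaf and returns the leaf's class. */
    static int classify(Node node, double[] row) {
        while (node.nodeType == -1) {
            node = row[node.divideType + 3] <= node.valveValue ? node.left : node.right;
        }
        return node.nodeType;
    }

    /** Fraction of rows whose predicted class differs from the label in column 2. */
    static double errorRatio(Node root, double[][] rows) {
        if (rows.length == 0) return 0.0;
        int wrong = 0;
        for (double[] row : rows) {
            if (classify(root, row) != (int) row[2]) wrong++;
        }
        return (double) wrong / rows.length;
    }
}
```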
In this example I set the parameters to:
int AttrNum=4 ;
int[] AttrType={0,0,0,0} ;
String[] SqlColName={"x","y","x+y","x-y"} ;
String[] AttrDescription={"x","y","x+y","x-y"} ;
int dataNodeNum=2000 ;
String[] classNum={"Average","Normal"} ;
int RuleDeep=8 ;
int NodeLimitNum=10 ;
double NodeLimitEntropy=0 ;
double NodeDeltaEntropy=0 ;
The result is a rule tree with 131 nodes in total, an error rate of 15.65% on the training set and 19.2% on the test set. The generated rule tree is:
x+y<=14.244
| x+y<=12.206
| | y<=4.7925: Class0
| | y>4.7925
| | | y<=4.8215: Class1
| | | y>4.8215: Class0
| x+y>12.206
| | x<=5.204: Class0
| | x>5.204
| | | x<=7.156499999999999
| | | | x<=5.2625: Class1
| | | | x>5.2625
| | | | | x<=5.9145: Class0
| | | | | x>5.9145: Class0
| | | x>7.156499999999999
| | | | x+y<=13.5105: Class0
| | | | x+y>13.5105
| | | | | x+y<=13.524999999999999: Class1
| | | | | x+y>13.524999999999999
| | | | | | x+y<=13.92: Class0
| | | | | | x+y>13.92: Class0
x+y>14.244
| x+y<=23.326500000000003
| | x<=3.101
| | | y<=16.8975
| | | | x<=0.8965000000000001: Class0
| | | | x>0.8965000000000001
| | | | | y<=12.908000000000001: Class0
| | | | | y>12.908000000000001
| | | | | | x<=0.9964999999999999: Class1
| | | | | | x>0.9964999999999999
| | | | | | | x-y<=-10.7015
| | | | | | | | x+y<=16.7165: Class0
| | | | | | | | x+y>16.7165: Class0
| | | | | | | x-y>-10.7015: Class1
| | | y>16.8975: Class0
| | x>3.101
| | | y<=2.8085
| | | | x-y<=14.558
| | | | | x<=16.951
| | | | | | y<=1.1775: Class0
| | | | | | y>1.1775
| | | | | | | x-y<=13.814499999999999
| | | | | | | | x+y<=14.5165: Class1
| | | | | | | | x+y>14.5165: Class0
| | | | | | | x-y>13.814499999999999: Class1
| | | | | x>16.951: Class1
| | | | x-y>14.558
| | | | | y<=2.1395: Class0
| | | | | y>2.1395: Class0
| | | y>2.8085
| | | | x+y<=16.6265
| | | | | y<=12.838000000000001
| | | | | | x<=6.4335
| | | | | | | x-y<=-7.651499999999999: Class0
| | | | | | | x-y>-7.651499999999999: Class0
| | | | | | x>6.4335
| | | | | | | x<=9.96
| | | | | | | | x-y<=3.2985: Class1
| | | | | | | | x-y>3.2985: Class1
| | | | | | | x>9.96
| | | | | | | | x-y<=9.027999999999999: Class0
| | | | | | | | x-y>9.027999999999999: Class1
| | | | | y>12.838000000000001: Class1
| | | | x+y>16.6265
| | | | | x-y<=13.459
| | | | | | x+y<=22.872999999999998
| | | | | | | x-y<=-8.513
| | | | | | | | x-y<=-10.426: Class1
| | | | | | | | x-y>-10.426: Class1
| | | | | | | x-y>-8.513
| | | | | | | | x<=7.8805: Class1
| | | | | | | | x>7.8805: Class1
| | | | | | x+y>22.872999999999998
| | | | | | | x+y<=23.162499999999998
| | | | | | | | x<=17.2825: Class0
| | | | | | | | x>17.2825: Class1
| | | | | | | x+y>23.162499999999998
| | | | | | | | x<=5.6205: Class0
| | | | | | | | x>5.6205: Class1
| | | | | x-y>13.459
| | | | | | x+y<=21.8705
| | | | | | | x+y<=21.7155: Class0
| | | | | | | x+y>21.7155: Class1
| | | | | | x+y>21.8705: Class0
| x+y>23.326500000000003
| | x+y<=27.7235
| | | y<=17.5545
| | | | y<=10.692499999999999
| | | | | x+y<=24.683999999999997
| | | | | | x+y<=23.731499999999997: Class0
| | | | | | x+y>23.731499999999997
| | | | | | | x+y<=23.7855: Class1
| | | | | | | x+y>23.7855
| | | | | | | | x+y<=24.435000000000002: Class0
| | | | | | | | x+y>24.435000000000002: Class0
| | | | | x+y>24.683999999999997
| | | | | | x+y<=26.735500000000002: Class0
| | | | | | x+y>26.735500000000002
| | | | | | | x+y<=26.7965: Class1
| | | | | | | x+y>26.7965: Class0
| | | | y>10.692499999999999
| | | | | x+y<=25.2045
| | | | | | x-y<=-2.4989999999999997
| | | | | | | y<=14.745000000000001
| | | | | | | | x<=11.157499999999999: Class0
| | | | | | | | x>11.157499999999999: Class1
| | | | | | | y>14.745000000000001
| | | | | | | | y<=15.202: Class1
| | | | | | | | y>15.202: Class0
| | | | | | x-y>-2.4989999999999997
| | | | | | | x<=12.005500000000001
| | | | | | | | x<=10.674: Class0
| | | | | | | | x>10.674: Class1
| | | | | | | x>12.005500000000001
| | | | | | | | x<=12.3055: Class0
| | | | | | | | x>12.3055: Class1
| | | | | x+y>25.2045
| | | | | | x<=10.908000000000001: Class0
| | | | | | x>10.908000000000001
| | | | | | | x+y<=25.803: Class0
| | | | | | | x+y>25.803
| | | | | | | | x<=11.1095: Class1
| | | | | | | | x>11.1095: Class0
| | | y>17.5545: Class0
| | x+y>27.7235
| | | y<=12.2935
| | | | x<=17.162: Class1
| | | | x>17.162: Class0
| | | y>12.2935: Class0
Appendix: Code
ProjectInfo
package Classification;
public class ProjectInfo {
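public int AttrNum ; //Number of attributes; with 2 attributes, AttrNum=2
public int[] AttrType ; //AttrType[i] is the type of attribute i: 0 for double, 1 for boolean
public String[] SqlColName; //Column names in the SQL database
public String[] AttrDescription ; //Description of each attribute
public int dataNodeNum ; //Number of records in the data set
public String[] classNum ; //Names of the classes the data belongs to
//Row layout: double[i][0] stores SetNum, double[i][1] stores Type, double[i][2] stores tempType
//double[i][3] to double[i][2+AttrNum] store the attribute values
//public int addAttrNum ; //Number of custom attributes, default 0; a value >0 enables the custom attribute function
public int RuleDeep ; //A node deeper than this value is not split further
public int NodeLimitNum ; //A node holding this many records or fewer is not split further
public double NodeLimitEntropy ; //A node whose entropy is less than or equal to this value is not split further
public double NodeDeltaEntropy ; //A split that lowers the entropy by less than this value is rejected
public ProjectInfo(){ //empty object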
this.AttrNum=0 ;
this.AttrType=null ;
this.SqlColName=null ;
this.AttrDescription=null ;
this.dataNodeNum=0 ;
this.classNum=null ;
//this.addAttrNum=0 ;
this.RuleDeep=100 ;
this.NodeLimitNum=1 ;
this.NodeLimitEntropy=0 ;
this.NodeDeltaEntropy=1 ;
}
public ProjectInfo(int AttrNum,int[] AttrType,String[] SqlColName,
String[] AttrDescription,int dataNodeNum,String[] classNum,
int RuleDeep,int NodeLimitNum,
double NodeLimitEntropy,double NodeDeltaEntropy){
this.AttrNum=AttrNum ;
if(AttrType.length!=this.AttrNum) System.out.println("ERROR:ProjectInfo_AttrType has the wrong number of entries");
this.AttrType=AttrType ;
if(SqlColName.length!=this.AttrNum) System.out.println("ERROR:ProjectInfo_SqlColName has the wrong number of entries");
this.SqlColName=SqlColName ;
if(AttrDescription.length!=this.AttrNum) System.out.println("ERROR:ProjectInfo_AttrDescription has the wrong number of entries");
this.AttrDescription=AttrDescription ;
this.dataNodeNum=dataNodeNum ;
this.classNum=classNum ;
//this.addAttrNum=addAttrNum ;
this.RuleDeep=RuleDeep ;
this.NodeLimitNum=NodeLimitNum ;
this.NodeLimitEntropy=NodeLimitEntropy ;
this.NodeDeltaEntropy=NodeDeltaEntropy ;
}
public static void main(String[] args){
int AttrNum=4 ;
int[] AttrType={0,0,0,0} ;
String[] SqlColName={"x","y","x+y","x-y"} ;
String[] AttrDescription={"x","y","x+y","x-y"} ;
int dataNodeNum=2000 ;
String[] classNum={"Average","Normal"} ;
//int addAttrNum=0 ;
int RuleDeep=8 ;
int NodeLimitNum=10 ;
double NodeLimitEntropy=0 ;
double NodeDeltaEntropy=0 ;
ProjectInfo pInfo=new ProjectInfo(AttrNum,AttrType,SqlColName,
AttrDescription,dataNodeNum,classNum,RuleDeep,
NodeLimitNum,NodeLimitEntropy,NodeDeltaEntropy) ;
FileIO ep=new FileIO() ;
int len=2000 ;
int width=3 ; //number of columns in the Excel sheet
int[] nodea={0,0} ;
int[] nodeb={width-1,len-1} ;
double[][] get1=ep.getArray("origin",0,nodea, nodeb) ;
double[][] trainset=new double[len][pInfo.AttrNum+3] ;
for(int i=0;i<len;i++){
trainset[i][0]=i+1 ;
trainset[i][1]=get1[i][2] ;
trainset[i][2]=get1[i][2] ;
trainset[i][3]=get1[i][0] ;
trainset[i][4]=get1[i][1] ;
trainset[i][5]=get1[i][0]+get1[i][1] ;
trainset[i][6]=get1[i][0]-get1[i][1] ;
}
len=1000 ;
width=3 ;
nodeb[0]=width-1 ;
nodeb[1]=len-1 ;
double[][] get2=ep.getArray("origin",1,nodea, nodeb) ;
double[][] examset=new double[len][pInfo.AttrNum+3] ;
for(int i=0;i<len;i++){
examset[i][0]=i+1 ;
examset[i][1]=get2[i][2] ;
examset[i][2]=get2[i][2] ;
examset[i][3]=get2[i][0] ;
examset[i][4]=get2[i][1] ;
examset[i][5]=get2[i][0]+get2[i][1] ;
examset[i][6]=get2[i][0]-get2[i][1] ;
}
RuleNode root=new RuleNode(trainset,0,"0",pInfo) ;
root.PrintRule("rule", root, pInfo);
RuleNode rule=root ;
System.out.println("Number of nodes in the rule tree: "+rule.getNodeNum(root));
Estimate es=new Estimate(rule,trainset,pInfo) ;
System.out.println("Training set error rate: "+es.ErrRatio);
es=new Estimate(rule,examset,pInfo) ;
System.out.println("Test set error rate: "+es.ErrRatio);
//es=new Estimate(rule,trainset,pInfo) ;
//System.out.println("Training set error rate: "+es.ErrRatio);
}
}
RuleNode
package Classification;
public class RuleNode {
public int deep ;
public double formerEntropy ;
public double[][] datalist ;
public String tag ;
public int nodeType=-1 ;
//public int minSetNum=1 ; //minimum number of records in a valid branch node; below this the node is not split
public int divideType=-1 ;
public double valveValue=-1 ;
public double deltaEntropy=1 ;
public RuleNode leftChild=null ;
public RuleNode rightChild=null ;
public double laterEntropy=2 ;
public RuleNode(double[][] input,int deep,String tag,ProjectInfo pInfo){
//System.out.println("Building node "+tag); //used for debug
//System.out.println("Number of records in node "+tag+": "+input.length);
this.deep=deep ;
this.tag=tag ;
this.datalist=input ;
this.formerEntropy=getDataListEntropy(input) ; //undefined
this.nodeType=-1 ;
if ((this.deep>pInfo.RuleDeep)||(this.datalist.length<=pInfo.NodeLimitNum)||(this.formerEntropy<=pInfo.NodeLimitEntropy)){
//Too deep, too few records, or the data is already pure enough: stop splitting
this.leftChild=this.rightChild=null ;
int temp=decideType(this.datalist) ; //undefined
if ((temp==0)||(temp==1)) this.nodeType=temp ;
else System.out.println("ERROR: decideType returned an illegal value") ;
}else{
Hunt hunt=new Hunt(input,pInfo) ; //undefined
this.divideType=hunt.type ;
this.valveValue=hunt.value_value ;
this.laterEntropy=hunt.min_entropy ;
this.deltaEntropy=this.formerEntropy-this.laterEntropy ;
if (this.deltaEntropy<pInfo.NodeDeltaEntropy){
this.leftChild=this.rightChild=null ; //the best split lowers the entropy by less than NodeDeltaEntropy, so stop here
int temp=decideType(this.datalist) ;
if ((temp==0)||(temp==1)) this.nodeType=temp ;
else System.out.println("ERROR: decideType returned an illegal value") ;
}else{
//System.out.println("tag1") ; //used for debug
double[][] leftList=Divide(input,this.divideType,this.valveValue,0,pInfo) ;
double[][] rightList=Divide(input,this.divideType,this.valveValue,1,pInfo) ;
//if (tag=="001") System.out.println(leftChild==null) ;//used for debug
//if (leftList==null) System.out.println("left subtree of node "+tag+" is empty") ; else System.out.println("left subtree of node "+tag+" has length "+leftList.length) ; //used for debug
//if (rightList==null) System.out.println("right subtree of node "+tag+" is empty") ; else System.out.println("right subtree of node "+tag+" has length "+rightList.length) ; //used for debug
//if ((leftList==null)||(rightList==null)) this.leftChild=this.rightChild=null ;
if ((leftList.length==0)||(rightList.length==0)) {
this.leftChild=this.rightChild=null ;
int temp=decideType(this.datalist) ;
if ((temp==0)||(temp==1)) this.nodeType=temp ;
else System.out.println("ERROR: decideType returned an illegal value") ;
}
else{
this.leftChild=new RuleNode(leftList,deep+1,tag+'0',pInfo) ;
this.rightChild=new RuleNode(rightList,deep+1,tag+'1',pInfo) ;
}
}
}
}
public static double[][] Divide(double[][] input,int attribute,double valve,int methodtype,ProjectInfo pInfo){
double[][] rs=null ;
//Split input into two parts by attribute and valve; methodtype selects which part is returned
if (methodtype==0){ //methodtype=0: keep the records whose attribute value<=valve
int num=0 ;
for(int i=0;i<input.length;i++){
if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
if (input[i][attribute+3]<=valve) num++ ;
}else System.out.println("ERROR:The value of attribute value illegal");
}
rs=new double[num][pInfo.AttrNum+3] ;
int index=0 ;
for(int i=0;i<input.length;i++){
if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
if (input[i][attribute+3]<=valve) rs[index++]=input[i] ;
}else System.out.println("ERROR:The value of attribute value illegal");
}
return rs ;
}else if(methodtype==1){
int num=0 ;
for(int i=0;i<input.length;i++){
if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
if (input[i][attribute+3]>valve) num++ ;
}else System.out.println("ERROR:The value of attribute value illegal");
}
rs=new double[num][pInfo.AttrNum+3] ;
int index=0 ;
for(int i=0;i<input.length;i++){
if ((attribute>=0)&&(attribute<pInfo.AttrNum)){
if (input[i][attribute+3]>valve) rs[index++]=input[i] ;
}else System.out.println("ERROR:The value of attribute value illegal");
}
return rs ;
}else System.out.println("ERROR:RuleNode_Divide_methodtype value illegal");
return rs ;
}
public int getNodeNum(RuleNode node){
if (node.nodeType==-1){
int num1=getNodeNum(node.leftChild) ;
int num2=getNodeNum(node.rightChild) ;
return num1+num2+1 ;
}else{
return 1 ;
}
}
public void PrintRule(String filename,RuleNode node,ProjectInfo pInfo){
String out="" ;
for(int i=0;i<node.deep;i++) out+="| " ;
if (node.leftChild.nodeType==-1) {
FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+"<="+node.valveValue);
PrintRule(filename,node.leftChild,pInfo) ;
}else{
FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+"<="+node.valveValue+": Class"+node.leftChild.nodeType);
}
if (node.rightChild.nodeType==-1) {
FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+">"+node.valveValue);
PrintRule(filename,node.rightChild,pInfo) ;
}else{
FileIO.PrintTxtln(filename, out+pInfo.AttrDescription[node.divideType]+">"+node.valveValue+": Class"+node.rightChild.nodeType);
}
}
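private static double getDataListEntropy(double[][] input){ //entropy of datalist, computed from the input 2-D array
DataProperty dp=new DataProperty() ;
double rs_entropy=-1 ;
//The entropy is computed from tempType (column 2)
//tempType takes 3 values: 0 for the first class, 1 for the second class, -1 for anything else (normally an anomaly)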
int num1=0,num2=0 ;
for(int i=0;i<input.length;i++){
if(input[i][2]==0) num1++ ;
if(input[i][2]==1) num2++ ;
}
rs_entropy=dp.getEntropy(num1, num2) ;
return rs_entropy ;
}
private static int decideType(double[][] input){
int rs=-1 ;
int num0=0,num1=0 ;
for(int i=0;i<input.length;i++){
if (input[i][2]==0) num0++ ;
if (input[i][2]==1) num1++ ;
}
if (num0<num1) rs=1 ; //on a tie (num0=num1) the node goes to class 0; a random choice could be used instead
else rs=0 ;
return rs ;
}
}
Hunt
package Classification;
public class Hunt {
public double min_entropy ;
public double value_value ;
public int type ;
public Hunt(double[][] dataset,ProjectInfo pInfo){
//1. calculate the entropy of initial dataset
//2. find the best attribute among the AttrNum candidates
double[][] rs=new double[pInfo.AttrNum][2] ;
int mintype=-1 ;
double minentropy=1 ;
double valve_value=-1 ;
for(int i=0;i<pInfo.AttrNum;i++){
if (pInfo.AttrType[i]==0) rs[i]=FindBestValve(preDeal(dataset,i,pInfo),pInfo) ;
else if (pInfo.AttrType[i]==1) {rs[i]=BoolAttr(dataset,i) ;} ///////// undefined
else System.out.println("ERROR:Hunt_Hunt_pInfo.AttrType["+i+"] is set incorrectly");
//rs[i][0]=entropy rs[i][1]=valve
if(rs[i][0]<minentropy){
minentropy=rs[i][0] ;
valve_value=rs[i][1] ;
mintype=i ;
}
}
//3. find the best one and output
this.min_entropy=minentropy ;
this.value_value=valve_value ;
this.type=mintype ;
}
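private static double[] BoolAttr(double[][] input,int Attr){
//Weighted class entropy after splitting on a boolean attribute (0 goes left, 1 goes right),
//so the result is comparable with the value returned by FindBestValve
int total_len=input.length ;
int[][] cnt=new int[2][2] ; //cnt[v][c]: records with attribute value v and tempType c
int counted=0 ;
for(int i=0;i<total_len;i++){
double v=input[i][3+Attr] ;
double c=input[i][2] ;
if (((v==0)||(v==1))&&((c==0)||(c==1))) {cnt[(int)v][(int)c]++ ; counted++ ;}
}
if(counted!=total_len) System.out.println("ERROR:Hunt_BoolAttr: unexpected value in a boolean attribute or in tempType") ;
DataProperty dp=new DataProperty() ;
double[] rs=new double[2] ;
int len1=cnt[0][0]+cnt[0][1] ;
int len2=cnt[1][0]+cnt[1][1] ;
rs[0]=(counted==0)?0:(dp.getEntropy(cnt[0][0],cnt[0][1])*len1+dp.getEntropy(cnt[1][0],cnt[1][1])*len2)/counted ;
rs[1]=0.5 ; //a threshold of 0.5 separates the values 0 and 1
return rs ;
}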
private static double[][] preDeal(double[][] dataset,int attr,ProjectInfo pInfo){ //reduce each record to 3 columns to fit the following processing
if ((attr<pInfo.AttrNum)&&(attr>=0)){
double[][] rs=new double[dataset.length][3] ; //3 columns: attribute value, record number, tempType
for(int i=0;i<dataset.length;i++){
rs[i][0]=dataset[i][attr+3] ;
rs[i][1]=dataset[i][0] ;
rs[i][2]=dataset[i][2] ; //ATTENTION the taken value is tempType!
}
return rs ;
}else {System.out.println("ERROR:Hunt_preDeal: illegal attribute index");return null ;}
}
private static double[] FindBestValve(double[][] input,ProjectInfo pInfo){
//Type may take several values; ideally it has only two
//print(input) ; // used for debug
double[][] sorted=QuickSort(input,0,input.length-1) ; //1st step:sort the input array
//Next, loop over the gaps between distinct values and pick the one with the lowest entropy
double min_entropy=2 ;
double valve_value=-1 ;
//System.out.println(sorted==null); // used for debug
for(int i=0;i<sorted.length-1;i++){
// calculate the entropy of the division whose valve is between i and i+1
if (sorted[i][0]!=sorted[i+1][0]){ //never place a threshold between two equal values
double temp_entropy=CalculateEntropy(sorted,i) ;
if (temp_entropy<min_entropy){
min_entropy=temp_entropy ;
valve_value=(sorted[i][0]+sorted[i+1][0])/2 ;
}
}
}
double[] rs=new double[2] ;
rs[0]=min_entropy ;
rs[1]=valve_value ;
return rs ;
}
private static double CalculateEntropy(double[][] sorted,int i) { //can only deal with the data which have only two classes
DataProperty dp=new DataProperty() ; //initialization of dataproperty
double rs_entropy=-1 ;
//double tagclass=sorted[0][2] ;
int num1=0 ;
int num2=0 ;
for(int x=0;x<i+1;x++){
if(sorted[x][2]==0) num1++ ;
else if(sorted[x][2]==1) num2++ ;
else System.out.println("ERROR from CalculateEntropy: the value of tempType of a item is -1");
}
double entropy1=dp.getEntropy(num1,num2) ;
int tnum1=num1+num2 ; //total number of the former sequence
num1=0 ;
num2=0 ;
for(int x=i+1;x<sorted.length;x++){
if(sorted[x][2]==0) num1++ ;
else if(sorted[x][2]==1) num2++ ;
else System.out.println("ERROR from CalculateEntropy: the value of tempType of a item is -1");
}
double entropy2=dp.getEntropy(num1,num2) ;
int tnum2=num1+num2 ;
rs_entropy=(entropy1*tnum1+entropy2*tnum2)/(tnum1+tnum2) ;
return rs_entropy ;
}
private static double[][] QuickSort(double[][] input,int low,int high){
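if(low>high) return null ;
if(low==high) return new double[][]{input[low]} ; //a single record is already sorted; returning null here would break FindBestValve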
int first=low ;
int last=high ;
double[] key=input[low] ;
while(first<last){
while((first<last)&&(input[last][0]>=key[0])) --last ;
input[first]=input[last] ;
while((first<last)&&(input[first][0]<=key[0])) ++first ;
input[last]=input[first] ;
}
input[first]=key ;
double[][] res1,res2 ;
if (first-1>low) {res1=QuickSort(input,low,first-1) ;}
else if(first-1==low) {double[][] temp={input[low]} ;res1=temp ;}
else{res1=null ;}
if(high>first+1){res2=QuickSort(input,first+1,high) ;}
else if(high==first+1){double[][] temp={input[high]} ;res2=temp ;}
else{res2=null ;}
double[][] finalres ;
finalres=Combine(res1,res2,key) ;
return finalres ;
}
private static double[][] Combine(double[][] res1,double[][] res2,double[] key){
int len1,len2 ;
if(res1==null) len1=0 ;
else len1=res1.length ;
if(res2==null) len2=0 ;
else len2=res2.length ;
double[][] res=new double[len1+len2+1][3] ;
int index=0 ;
for(int i=0;i<len1;i++) res[index++]=res1[i] ;
res[index++]=key ;
for(int i=0;i<len2;i++) res[index++]=res2[i] ;
return res ;
}
}
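FindBestValve above calls CalculateEntropy for every candidate threshold, and each call recounts both sides from scratch. As a hedged alternative, the same search can keep running class counts while it walks the sorted values. The sketch below is not the original implementation: ThresholdSketch and bestThreshold are hypothetical names, the input is simplified to rows of {attributeValue, label} with label 0 or 1, and the library sort replaces the hand-written quicksort.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical sketch for illustration; not part of the original code.
final class ThresholdSketch {
    /** Binary entropy in bits for n0 records of class 0 and n1 of class 1. */
    static double entropy(int n0, int n1) {
        if (n0 == 0 || n1 == 0) return 0.0;
        double total = n0 + n1;
        double p0 = n0 / total;
        double p1 = n1 / total;
        return -(p0 * Math.log(p0) + p1 * Math.log(p1)) / Math.log(2);
    }

    /**
     * rows[i] = {attributeValue, label}, label 0 or 1.
     * Returns {lowest weighted entropy, threshold}; the threshold is NaN when
     * all attribute values are equal and no split exists.
     */
    static double[] bestThreshold(double[][] rows) {
        double[][] sorted = rows.clone(); // shallow copy, the caller's order is untouched
        Arrays.sort(sorted, Comparator.comparingDouble((double[] r) -> r[0]));
        int total0 = 0;
        int total1 = 0;
        for (double[] r : sorted) {
            if (r[1] == 0) total0++; else total1++;
        }
        int left0 = 0;
        int left1 = 0;
        double best = Double.POSITIVE_INFINITY;
        double threshold = Double.NaN;
        for (int i = 0; i < sorted.length - 1; i++) {
            if (sorted[i][1] == 0) left0++; else left1++; // record i moves to the left side
            if (sorted[i][0] == sorted[i + 1][0]) continue; // never cut between equal values
            int nLeft = left0 + left1;
            int nRight = sorted.length - nLeft;
            double e = (entropy(left0, left1) * nLeft
                      + entropy(total0 - left0, total1 - left1) * nRight) / sorted.length;
            if (e < best) {
                best = e;
                threshold = (sorted[i][0] + sorted[i + 1][0]) / 2; // midpoint, as in FindBestValve
            }
        }
        return new double[] { best, threshold };
    }
}
```

After sorting, each candidate costs constant time, so the whole search is dominated by the sort.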
Estimate
package Classification;
import java.util.*;
public class Estimate {
List<double[]> list0 ;
List<double[]> list1 ;
double[][] array0 ;
double[][] array1 ;
RuleNode examtree ;
double ErrRatio ;
public Estimate(RuleNode rule,double[][] examset,ProjectInfo pInfo){
this.list0=new ArrayList<double[]>() ;
this.list1=new ArrayList<double[]>() ;
this.examtree=examTree(rule,examset,pInfo) ;
this.ErrRatio=getErrRatio(this.list0,this.list1) ;
this.array0=convert(list0,pInfo) ;
this.array1=convert(list1,pInfo) ;
}
private double getErrRatio(List list0,List list1){
double len1=list0.size() ;
double len2=list1.size() ;
double errnum1=0,errnum2=0 ;
for(int i=0;i<len1;i++){
double[] temp=(double[])list0.get(i) ;
if (temp[2]==1) errnum1++ ;
}
for(int i=0;i<len2;i++){
double[] temp=(double[])list1.get(i) ;
if(temp[2]==0) errnum2++ ;
}
double erratio=(errnum1+errnum2)/(len1+len2) ;
return erratio ;
}
private double[][] convert(List list,ProjectInfo pInfo){
int len=list.size() ;
double[][] rs=new double[len][3+pInfo.AttrNum] ;
for(int i=0;i<len;i++){
rs[i]=(double[])list.get(i) ;
}
return rs ;
}
private RuleNode examTree(RuleNode node,double[][] data,ProjectInfo pInfo){
node.datalist=data ;
node.formerEntropy=getDataListEntropy(data) ;
if (node.nodeType==-1) { //this node is not a leaf node
double[][] left=RuleNode.Divide(data, node.divideType, node.valveValue, 0,pInfo) ;
double[][] right=RuleNode.Divide(data, node.divideType, node.valveValue, 1,pInfo) ;
//Skip an empty branch but keep the child node, so the same tree can be evaluated again on another data set
if (left.length>0) examTree(node.leftChild,left,pInfo) ;
if (right.length>0) examTree(node.rightChild,right,pInfo) ;
return node ;
}else{ // this node is a leaf node
node.leftChild=null ;
node.rightChild=null ;
int len=node.datalist.length ; //move the records predicted as class 0 or 1 into list0 or list1
int num=0 ;
if (node.nodeType==0) {
for(int i=0;i<len;i++) list0.add(node.datalist[i]) ;
}
if (node.nodeType==1) {
for(int i=0;i<len;i++) list1.add(node.datalist[i]) ;
}
return node ;
}
}
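private static double getDataListEntropy(double[][] input){ //entropy of datalist, computed from the input 2-D array
DataProperty dp=new DataProperty() ;
double rs_entropy=-1 ;
//The entropy is computed from tempType (column 2)
//tempType takes 3 values: 0 for the first class, 1 for the second class, -1 for anything else (normally an anomaly)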
int num1=0,num2=0 ;
for(int i=0;i<input.length;i++){
if(input[i][2]==0) num1++ ;
if(input[i][2]==1) num2++ ;
}
rs_entropy=dp.getEntropy(num1, num2) ;
return rs_entropy ;
}
}
FileIO
package Classification;
//import java.io.File;
import java.io.* ;
import jxl.Workbook;
import jxl.write.Label;
import jxl.write.WritableSheet;
import jxl.write.WritableWorkbook;
import jxl.Cell ;
import jxl.Sheet ;
public class FileIO {
public FileIO(){
}
public void PrintDoubleArray(double[][] input,String filename,ProjectInfo pInfo){
try{
String rootname="C:\\Users\\multiangle\\Desktop\\" ;
String path=rootname+filename+".xls" ;
File file=new File(path) ;
WritableSheet sheet ;
WritableWorkbook book ;
if (file.exists()) {
Workbook wb=Workbook.getWorkbook(file) ;
book=Workbook.createWorkbook(file, wb) ;
int sheetnum=book.getNumberOfSheets() ;
sheet=book.createSheet("Sheet "+sheetnum, sheetnum) ;
System.out.println("Writing the double array to sheet "+sheetnum);
}else {
book=Workbook.createWorkbook(new File(path)) ;
sheet=book.createSheet("Sheet 0", 0) ;
System.out.println("Writing the double array to sheet 0");
}
//System.out.println("Got the required sheet");
String[] name=new String[3+pInfo.AttrNum] ;
name[0]="SetNum" ;
name[1]="Type" ;
name[2]="tempType" ;
for(int i=0;i<pInfo.AttrNum;i++){
name[3+i]=pInfo.AttrDescription[i] ;
}
for(int i=0;i<3+pInfo.AttrNum;i++){
Label temp=new Label(i,0,name[i]) ;
sheet.addCell(temp);
}
int len=input.length ;
int row=1 ;
for(int i=0;i<len;i++){
for(int j=0;j<3+pInfo.AttrNum;j++){
Label temp=new Label(j,row+i,String.valueOf(input[i][j])) ;
sheet.addCell(temp);
}
}
book.write() ;
book.close();
}catch(Exception e){
System.out.println(e) ;
System.out.println("ERROR:ExcelPrint") ;
}
}
public void PrintDoubleArray(double[][] input,String filename,String description,ProjectInfo pInfo){
try{
String rootname="C:\\Users\\multiangle\\Desktop\\" ;
String path=rootname+filename+".xls" ;
File file=new File(path) ;
WritableSheet sheet ;
WritableWorkbook book ;
if (file.exists()) {
Workbook wb=Workbook.getWorkbook(file) ;
book=Workbook.createWorkbook(file, wb) ;
int sheetnum=book.getNumberOfSheets() ;
sheet=book.createSheet("Sheet "+sheetnum, sheetnum) ;
System.out.println("Writing the double array to sheet "+sheetnum);
}else {
book=Workbook.createWorkbook(new File(path)) ;
sheet=book.createSheet("Sheet 0", 0) ;
System.out.println("Writing the double array to sheet 0");
}
//System.out.println("Got the required sheet");
Label descrip=new Label(0,0,description) ;
sheet.addCell(descrip);
String[] name=new String[3+pInfo.AttrNum] ;
name[0]="SetNum" ;
name[1]="Type" ;
name[2]="tempType" ;
for(int i=0;i<pInfo.AttrNum;i++){
name[3+i]=pInfo.AttrDescription[i] ;
}
for(int i=0;i<3+pInfo.AttrNum;i++){
Label temp=new Label(i,1,name[i]) ;
sheet.addCell(temp);
}
int len=input.length ;
int row=2 ;
for(int i=0;i<len;i++){
for(int j=0;j<3+pInfo.AttrNum;j++){
Label temp=new Label(j,row+i,String.valueOf(input[i][j])) ;
sheet.addCell(temp);
}
}
book.write() ;
book.close();
}catch(Exception e){
System.out.println(e) ;
System.out.println("ERROR:ExcelPrint") ;
}
}
public double[][] getArray(String filename,int sheetnum,int[] nodea,int[] nodeb){
int left=nodea[0] ;
int top=nodea[1] ;
int right=nodeb[0] ;
int bottom=nodeb[1] ;
int len=bottom-top+1 ;
int width=right-left+1 ;
double[][] rs=new double[len][width] ;
String root="C:\\Users\\multiangle\\Desktop\\" ; //base directory is the desktop
String path=root+filename+".xls" ;
File file=new File(path) ;
if (!file.exists()) {System.out.println("ERROR:ExcelIO_getArray_File not exists");return null ;}
else{
try{
Workbook book=Workbook.getWorkbook(file) ;
Sheet sheet=book.getSheet(sheetnum) ;
for(int i=0;i<len;i++){
for(int j=0;j<width;j++){
Cell cell=sheet.getCell(j,i) ;
double temp=Double.parseDouble(cell.getContents()) ;
rs[i][j]=temp ;
}
}
return rs ;
}catch(Exception e){
System.out.println("ERROR:ExcelIO_getArray");
System.out.println(e);
return null ;
}
}
}
public static void PrintTxtln(String filename,String line){
try{
String root="C:\\Users\\multiangle\\Desktop\\" ; //base directory is the desktop
String path=root+filename+".txt" ;
FileWriter out=new FileWriter(path,true) ;
out.write(line+"\r\n");
out.close();
}catch(Exception e){
System.out.println(e);
}
}
}
dataSqlGet
package Classification;
import java.sql.* ;
import Classification.ProjectInfo ;
public class dataSqlGet {
public double[][] dataset ;
public dataSqlGet(String DatabaseName,ProjectInfo pInfo){
ResultSet rs=getResultSet(DatabaseName) ;
this.dataset=ResultDeal(rs,pInfo) ;
}
private static ResultSet getResultSet(String DatabaseName){
String JDriver="com.microsoft.sqlserver.jdbc.SQLServerDriver";//SQL Server JDBC driver
String connectDB="jdbc:sqlserver://127.0.0.1:1433;DatabaseName=multiangle";//data source
try{
Class.forName(JDriver);//load the JDBC driver class by name
System.out.println("Database driver loaded");
}catch(ClassNotFoundException e){ //e.printStackTrace();
System.out.println("Failed to load the database driver");
System.out.println(e);
}
ResultSet rs ;
try{
String user="sa" ;
String password="admin" ;
Connection con=DriverManager.getConnection(connectDB,user,password);
System.out.println("Connected to the database");
Statement stmt=con.createStatement() ;
//String query="select ROW_NUMBER()over(order by class)as row,* from dbo.[bezdekIris.data]" ;
String query="select ROW_NUMBER()over(order by class)as row,* from "+DatabaseName ;
rs=stmt.executeQuery(query) ;
return rs ;
}catch(SQLException e){
System.out.println(e) ;
System.out.println("Failed to read from the database");
return null ;
}
}
public static double[][] ResultDeal(ResultSet rs,ProjectInfo pInfo){
System.out.println("dataNodeNum: "+pInfo.dataNodeNum) ;
try {
int len=pInfo.dataNodeNum ;
System.out.println("len: "+len);
double[][] dataset=new double[pInfo.dataNodeNum][pInfo.AttrNum+3] ;
int num=0 ;
while((num<len)&&(rs.next())){
dataset[num][0]=Integer.parseInt(rs.getString("row")) ;
String name=rs.getString("Class") ; //take special care that the class names match when loading the data
int type=-1 ;
int namelen=pInfo.classNum.length ;
for(int i=0;i<namelen;i++){
if(name.equals(pInfo.classNum[i])) type=i ;
}
dataset[num][1]=type ; //Type
dataset[num][2]=-1 ; //tempType
//0-> SetNum ;1->Type;2->tempType;3-(3+attr)->attrvalue
for(int i=0;i<pInfo.SqlColName.length;i++){
dataset[num][i+3]=Double.parseDouble(rs.getString(pInfo.SqlColName[i])) ;
}
num++ ;
//System.out.println(setnum+" "+SL+" "+SW+" "+PL+" "+PW+" "+type) ;
}
System.out.println("ResultSet parsed");
return dataset ;
} catch (SQLException e) {
System.out.println("Failed to parse the ResultSet");
System.out.println(e);
return null ;
}
}
}
DataProperty
package Classification;
public class DataProperty {
public double getGini(int[] data){
int len=data.length ;
double sum=0 ;
for(int i=0;i<len;i++) sum+=data[i] ;
double pre_gini=0 ;
for(int i=0;i<len;i++) pre_gini+= (data[i]/sum)*(data[i]/sum) ;
double gini=1-pre_gini ;
return gini ;
}
public double getGini(int a,int b){
double c=a+b ;
double gini=1-(a/c)*(a/c)-(b/c)*(b/c) ;
return gini ;
}
public double getEntropy(int[] data){
int len=data.length ;
double sum=0 ;
for(int i=0;i<len;i++) sum+=data[i] ; //get the summary of all data
double pre_entro=0 ;
for(int i=0;i<len;i++) {
if (data[i]!=0){
pre_entro+=(data[i]/sum)*Math.log(data[i]/sum)/Math.log(2) ;
}
}
double entro=-pre_entro ;
return entro ;
}
public double getEntropy(int ina,int inb){
double a=(double)ina ;
double b=(double)inb ;
double entro ;
if((a*b)!=0){
double c=a+b ;
double a1=(a/c)*mathLog2(a/c) ;
double b1=(b/c)*mathLog2(b/c) ;
entro=-a1-b1 ;
return entro ;
}else{
entro=0 ;
return entro ;
}
}
//inner methods----------------------------------------------------
private static double mathLog(double data,double bottom){
return Math.log(data)/Math.log(bottom) ;
}
private static double mathLog2(double data){
return Math.log(data)/Math.log(2) ;
}
}
MATLAB program that generates the data sets
clear
clc
a(1:3,1:2000)=0 ;
a(1,1:1200)=20*rand(1200,1) ;
a(2,1:1200)=20*rand(1200,1) ;
a(3,1:1200)=0 ;
a(1,1201:1400)=5+1.5.*randn(200,1) ;
a(2,1201:1400)=15+1.5.*randn(200,1) ;
a(1,1401:1600)=15+1.5.*randn(200,1) ;
a(2,1401:1600)=5+1.5.*randn(200,1) ;
a(1,1601:2000)=10+2.*randn(400,1) ;
a(2,1601:2000)=10+2.*randn(400,1) ;
a(3,1201:2000)=1 ;
plot(a(1,1:1200),a(2,1:1200),'y.')
hold on
plot(a(1,1201:2000),a(2,1201:2000),'r.')
c=a' ;
xlswrite('C:\Users\multiangle\Desktop\origin.xlsx',c,1)
b(1:3,1:1000)=0 ;
b(1,1:600)=20*rand(600,1) ;
b(2,1:600)=20*rand(600,1) ;
b(3,1:600)=0 ;
b(1,601:700)=5+1.5.*randn(100,1) ;
b(2,601:700)=15+1.5.*randn(100,1) ;
b(1,701:800)=15+1.5.*randn(100,1) ;
b(2,701:800)=5+1.5.*randn(100,1) ;
b(1,801:1000)=10+2.*randn(200,1) ;
b(2,801:1000)=10+2.*randn(200,1) ;
b(3,601:1000)=1 ;
hold on
plot(b(1,1:600),b(2,1:600),'g.')
hold on
plot(b(1,601:1000),b(2,601:1000),'b.')
d=b' ;
xlswrite('C:\Users\multiangle\Desktop\origin.xlsx',d,2)
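For readers without MATLAB, the training-set part of the script above can be approximated in Java. This is a sketch under assumptions, not a port the original project uses: DataGenSketch and trainingSet are hypothetical names, java.util.Random replaces rand and randn (so the actual numbers differ from the MATLAB run), and the rows are returned as {x, y, label} instead of being written to Excel.

```java
import java.util.Random;

// Hypothetical sketch for illustration; not part of the original code.
final class DataGenSketch {
    /** Returns 2000 rows of {x, y, label}: label 0 = uniform on the 20x20 square, label 1 = Gaussian cluster. */
    static double[][] trainingSet(long seed) {
        Random rnd = new Random(seed);
        double[][] a = new double[2000][3];
        for (int i = 0; i < 1200; i++) {       // 1200 uniform points
            a[i][0] = 20 * rnd.nextDouble();
            a[i][1] = 20 * rnd.nextDouble();
            a[i][2] = 0;
        }
        fillCluster(a, 1200, 1400, 5, 15, 1.5, rnd);   // 200 points around (5,15)
        fillCluster(a, 1400, 1600, 15, 5, 1.5, rnd);   // 200 points around (15,5)
        fillCluster(a, 1600, 2000, 10, 10, 2.0, rnd);  // 400 points around (10,10), wider spread
        return a;
    }

    /** Fills rows [from, to) with Gaussian points centered at (cx, cy) with standard deviation sd. */
    private static void fillCluster(double[][] a, int from, int to,
                                    double cx, double cy, double sd, Random rnd) {
        for (int i = from; i < to; i++) {
            a[i][0] = cx + sd * rnd.nextGaussian();
            a[i][1] = cy + sd * rnd.nextGaussian();
            a[i][2] = 1;
        }
    }
}
```

The test set in the script follows the same pattern with half as many points in each group.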