机器学习之决策树

    在之前说了用线性回归的方法来对训练数据进行训练,然后通过得到的方程式来对测试数据进行了测试,这里就介绍下,自己对于同样的问题而进行决策树的划分构造树结构。

在这里就不重复说训练数据的格式了,可以看看我之前写的线性回归的那一篇文章。

决策树的步骤:

一、 实验要求

对数据使用决策树的方法对鸢尾花进行分类,进行实验比较、精度比较并写成实验报告

二、 实验原理

决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。

决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别

决策树是一种十分常用的分类方法。他是一种监管学习,所谓监管学习就是给定一堆样本,每个样本都有一组属性和一个类别,这些类别是事先确定的,那么通过学习得到一个分类器,这个分类器能够对新出现的对象给出正确的分类。这样的机器学习就被称之为监督学习

决策数有两大优点:1)决策树模型可以读性好,具有描述性,有助于人工分析;2)效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。

 

三丶 实验思路    

 由之前(上一篇文章)的训练数据可以看出,该数据类型是属于数字型(还有一种叫做名称型),所以对于这样的数据的划分,应该采取“>=”“>”,“<”“<=”作为分割条件,这样能对树的构建更新明显和优化时间复杂度。

决策树模型的构建:

通过如下的表达式来进行一步步的树模型的构建.

Gini不纯度


熵(Entropy


信息增益(Information Gain


四丶算法实现(采用Java语言,基于Eclipse平台

 1:读取测试集数据,并对测试集进行每列数据分割的处理,主要是方便每个特征的划分参数处理。

 

 2:根据熵的计算公式,从而获取到最后一列特征(属于某种类型花)的信息熵

    

 

 3:根据公式,计算每个特征的信息熵,从而确定出最大信息增益熵,以便得到根节点属性。

 

 

其中,因为这是属于数值型的数据,所以自己在对数据进行处理的时候,是首先将每列的数据从小到大进行排序,然后通过迭代循环,依次找到两个连续点的中点值,来作为划分参数,来获得特征的信息熵,并依次对每次的参数得到的信息熵进行比较,从而得到该特征的最大的信息熵。

4:依次对测试集显示的4个属性按照(3)中的方法进行处理

5:在得到每个特征的信息熵之后,通过公式计算出相应的信息增益熵,从而来得到第一层的根节点属性。

6:通过不断的上述步骤,依次得到每层的叶节点和根节点的划分。

7:将上述步骤中得到的每层相对应的特征名,划分参数,所属类别,进行排序处理,方便之后打印出每层树的结构。(就是对参数的排序)

 

8:打印出决策树的结构

 

9:对测试集通过决策树来进行预测,得到准确度。

 

五丶 实验结果

   通过上述的步骤,从而得到以下的输出结果,其中包括树的结构还有自己选取部分测试数据进行测试的结果显示。

决策树的结构:

   

 

部分数据的结果分析(自己在训练集中独立分割的部分数据):

 

代码也贴出来吧(有比较多冗余的地方,没优化,需要的就看看下)

package machinetest;
/*
 * 进行决策树的构建
 * 
 */
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

//计算信息熵
public class ComputeInfoValues {
	static List<String> cunchuallinfo=new ArrayList<>(); //用来保存所有的数据
	static List<String> inputetextdata=new ArrayList<>();//测试数据
	static double[] everyvalue=new double[4];  //保存要进行输出的决策树的变量
	static  double[] maxvalue=new double[4];     //存储划分的参数
	static String[] tezhengname=new String[4];   //存储对应的特征名字
	static double Max_value=0;
      public static void main(String[] args){
   	  
    	  //读取txt文件,统计数据
    	  String url="H:\\Iris.txt";         //数据源txt路径
    	  cunchuallinfo=getAllInfo(url);    //得到输入数据源数据   	  
    	  
    	  //得到根的信息熵
    	  double geninfoshang=showGenShang();   	     	
    	  //得到Sepal.Length的信息熵
    	  double firstsepallength=showDiffierentShang(geninfoshang,0);
    	  double firstvalues=Max_value;     		//得到区分值(参数)
    	  everyvalue[0]=firstsepallength;
    	  maxvalue[0]=firstvalues;
    	  tezhengname[0]="Sepal.Length";
//    	  System.out.println("第一个特征的增益熵:"+firstsepallength);
//    	  System.out.println(Max_value);
    	  //得到Sepal.Width的信息熵
    	  double secondSepalWidth=showDiffierentShang(geninfoshang,1);
    	  double secondvalues=Max_value;
    	  everyvalue[1]=secondSepalWidth;
    	  maxvalue[1]=secondvalues;
    	  tezhengname[1]="Sepal.Width";
//    	  System.out.println("第二个特征的增益熵:"+secondSepalWidth);
//    	  System.out.println(Max_value);
    	  //得到Sepal.Width的信息熵
    	  double threadPetalLength=showDiffierentShang(geninfoshang,2);
    	  double threevalues=Max_value;
    	  everyvalue[2]=threadPetalLength;
    	  maxvalue[2]=threevalues;
    	  tezhengname[2]="Petal.Length";
//    	  System.out.println("第三个特征的增益熵:"+threadPetalLength);
//    	  System.out.println(Max_value);
    	  // 得到PetalWidth的信息熵
    	  double fourPetalWidth=showDiffierentShang(geninfoshang,3);
    	  double fourvalues=Max_value;
    	  everyvalue[3]=fourPetalWidth;
    	  maxvalue[3]=fourvalues;
    	  tezhengname[3]="Petal.Width";
//    	  System.out.println("第四个特征的增益熵:"+fourPetalWidth);
//    	  System.out.println(Max_value);    	 

    	  //找到对应要输出的内容,进行排序,方便打印
    	  dealShunXuValue(everyvalue,maxvalue,tezhengname);
    	  //打印决策树
    	  outputTreeConstruction(everyvalue,maxvalue,tezhengname); //输出决策树的结构
    	  judgeDataResult(firstvalues,secondvalues,threevalues,fourvalues);      //进行决策树测试
    	 
    	  //   	 
      }
     
      
      //将要进行输出的决策树的变量进行整合,也就是按信息增益熵从小到大进行排序
      private static void dealShunXuValue(double[] everyvalue,double[] maxvalue, String[] tezhengname) {
		for(int i=0;i<4;i++){   //冒泡排序(大的放在前面)
			for(int m=3;m>i;m--){
				if(everyvalue[m]>everyvalue[m-1]){
					//进行交换(信息增益熵)
					double temp=everyvalue[m];
					everyvalue[m]=everyvalue[m-1];
					everyvalue[m-1]=temp;
					//交换对应的特征名字
					String temp2=tezhengname[m];
					tezhengname[m]=tezhengname[m-1];
					tezhengname[m-1]=temp2;
					//交换对应的划分参数
					double temp3=maxvalue[m];
					maxvalue[m]=maxvalue[m-1];
					maxvalue[m-1]=temp3;
				}
			}
		}
	}

	 //输出决策树的结构(之前已经按照排序好的顺序进行输出)
      private static void outputTreeConstruction(double[] everyvalue,double[] maxvalue, String[] tezhengname) {
    	  //输出决策树的结构
    	  System.out.println("第一层:                                              "+ tezhengname[0]);   
    	  System.out.println("                           /      \\");
    	  System.out.println("                   (<"+maxvalue[0]+") /        \\"+" (>="+maxvalue[0]+")");
    	  System.out.println("                         /          \\");
    	  System.out.println("第二层:                            "+tezhengname[1]+"        (Iris-versicolor)" );
    	  System.out.println("                         /");
    	  System.out.println("               (<"+maxvalue[1]+") /    \\"+" (>="+maxvalue[1]+")");
    	  System.out.println("                     /      \\");
    	  System.out.println("第三层:             "+ tezhengname[2]+"      (Iris-versicolor)  ");    	  
    	  System.out.println("                   / ");
    	  System.out.println("         (<"+maxvalue[2]+") /    \\"+" (>="+maxvalue[2]+")");
    	  System.out.println("               /      \\");
    	  System.out.println("第四层:        "+tezhengname[3]+"   (Iris-versicolor) ");
    	  System.out.println("                /");
    	  System.out.println("       (<"+maxvalue[3]+") /   \\"+" (>="+maxvalue[3]+")");
    	  System.out.println("第五层:                /      \\");
    	  System.out.println("(Iris-versicolor)   (Iris-setosa) ");		
	}

	//对测试集数据进行测试
	private static void judgeDataResult(double firstvalues,
			double secondvalues, double threevalues, double fourvalues) {
		 进行测试集的比较
  	  String urlinput="H:\\text.txt";                             //测试集txt路径
  	  inputetextdata=getAllInfo(urlinput);                        //得到输入数据源数据
  	  int panduansuoyin=0;
  	  String result="";
  	  String[] ceshidatafinallylie=new String[inputetextdata.size()];  //存储测试数据中的最后一列,方便最后查看结果
  	  String[] panduanshuju=new String[inputetextdata.size()];        //存储通过决策树判断的结果(只需要存储最后一列)
  	  //进行测试
  	  for(int i=0;i<inputetextdata.size();i++){
  		  String[] text=inputetextdata.get(i).split(",");
  		  ceshidatafinallylie[panduansuoyin]=text[4]; //将测试集的最后一列进行存储
  		  if(Double.parseDouble(text[2])>threevalues){     //第三个特征大于MAX的情况   			    			      			  
  			  if(Double.parseDouble(text[3])>=fourvalues){
  				  panduanshuju[panduansuoyin]="Iris-versicolor";
  				  panduansuoyin++;
  			  } 
  			  else{                                            //第三个特征小于MAX的情况   
  				  if(Double.parseDouble(text[0])>firstvalues){   //第一个特征大于MAX的情况 (因为第一个比第二个特征增益熵大)  
  					  panduanshuju[panduansuoyin]="Iris-versicolor";
      				  panduansuoyin++;
  				  }
  				  else{                             //第一个特征小于MAX的情况  
  					   if(Double.parseDouble(text[1])>secondvalues){
  						   panduanshuju[panduansuoyin]="Iris-setosa";
  	        				  panduansuoyin++;
  					   }
  					   else{
  						   panduanshuju[panduansuoyin]="Iris-versicolor";
	        				  panduansuoyin++;
  					   }
  				  }	    			      
  			  }
  		  }
  		  else{                                  //第三个特征小于Max的情况 
  			  if(Double.parseDouble(text[3])>=fourvalues){       //第四个特征大于MAX的情况
  				  panduanshuju[panduansuoyin]="Iris-versicolor";
  				  panduansuoyin++;   	    			
  			  }
  			  else{
  				  panduanshuju[panduansuoyin]="Iris-setosa";      //第四个特征小于MAX的情况
  				  panduansuoyin++;	    			    
  			  }
  		  }
  	  }
  	  
  	  //进行比较结果
	  System.out.println("预测值\t\t\t\t实际值:\t\t\t\t结果");
  	  String panduanresult="";
  	  double totalrightnumber=0;
  	  double totalerrornumber=0;
	  for(int i=0;i<panduanshuju.length;i++){
  	      if(panduanshuju[i].equals(ceshidatafinallylie[i])){
  	    	  panduanresult="right";
  	    	totalrightnumber++;
  	      }
  	      else{
  	    	  panduanresult="error";
  	    	totalerrornumber++;
  	      }
  		  System.out.println(panduanshuju[i]+"\t\t\t"+ceshidatafinallylie[i]+"\t\t\t"+panduanresult);
  	  }
		//输出精确度
	  	//输出精准度(结果保留两位小数)
		double d = (totalrightnumber)/(totalerrornumber+totalrightnumber)*100;
		String resultdata = String.format("%.2f", d);				
		System.out.println("预测的精确度是:"+resultdata+"%");
}

	//求每个特征的信息增益熵
      private static double showDiffierentShang(double geninfoshang,int charpterindex) {
    	  //将训练数据进行排序(方便找到参数点,得到最大增益熵)
    	  double Max_zengyishang=0;
    	  double[] textdata=sortAllInfoData(cunchuallinfo,charpterindex);
    	  for(int i=1;i<textdata.length;i++){  //找到合适的参数,并得到最大的增益熵的数
	    	  double testnumber=(textdata[i]+textdata[i-1])/2;
	    	  int fuheerror=getFuheNumberInfo(cunchuallinfo,testnumber,charpterindex); 
	    	  int sepalerrortop=getTopNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-versicolor");  //大于5.0的个数
	    	  int sepalrighttop=getTopNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-setosa");
	
	    	  int fuheright=cunchuallinfo.size()-1-fuheerror;     
	    	  int sepalrightbottom=getBottomNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-setosa");
	    	  int sepalerrorbottom=getBottomNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-versicolor");
	
	    	  double SepallengthshangTop=computeShang(sepalerrortop, sepalrighttop);
	    	  double SepallengthshangBottom=computeShang(sepalerrorbottom,sepalrightbottom);
	
	    	 double computershangsecond=computeDevolopShang(geninfoshang,SepallengthshangBottom,SepallengthshangTop,cunchuallinfo.size(),fuheerror,fuheright);         
			 if(Max_zengyishang<=computershangsecond){  //判断是否是最大的增益熵
				 Max_zengyishang=computershangsecond;
				 Max_value=textdata[i];               //获取到单列中让增益熵最大的值
			 }	    	
    	  }
    	  return Max_zengyishang;
	}
      
//	//显示第一个特征的信息增益熵
//	private static double showSepalLengthShang(double geninfoshang) {
//				
//    	  int fuheerror=getFuheNumberInfo(cunchuallinfo,5.0,0);  //大于5.0的个数
//    	  int sepalerrortop=getTopNumberInfo(cunchuallinfo,5.0,0,"Iris-versicolor");  //大于5.0的个数
//    	  int sepalrighttop=getTopNumberInfo(cunchuallinfo,5.0,0,"Iris-setosa");
//    	  
//    	  int fuheright=cunchuallinfo.size()-1-fuheerror;      //小于5.0的个数
//    	  int sepalrightbottom=getBottomNumberInfo(cunchuallinfo,5.0,0,"Iris-setosa");
//    	  int sepalerrorbottom=getBottomNumberInfo(cunchuallinfo,5.0,0,"Iris-versicolor");
//    	  
//    	  double SepallengthshangTop=computeShang(sepalerrortop, sepalrighttop);
//    	  double SepallengthshangBottom=computeShang(sepalerrorbottom, sepalrightbottom);
//
    	  double zengyilv=computeIvValue(fuheerror,fuheright);    //计算增益率Iv
    	  System.out.println(zengyilv);
//    	  
//    	 double computershangfirst=computeDevolopShang(geninfoshang,SepallengthshangBottom,SepallengthshangTop,cunchuallinfo.size(),fuheerror,fuheright);         
//		 return computershangfirst;
//	}


	//得到IV值
	private static double computeIvValue(double value, double base) {
		double number1=((value/17)*Math.log((value/17))/Math.log(2));
		double number2=(value/17)*Math.log((value/17))/Math.log(2);
		double number=-(number1+number2);
		return number;
	}


	//得到根的信息熵
     private static double showGenShang() {
     int Irisnumber=findNeedNumberInfo(cunchuallinfo,"Iris-setosa",4);  //得到Iris-setosa的个数
   	 int noIrisnumber=findNeedNumberInfo(cunchuallinfo,"Iris-versicolor",4);  //得打所有反例的个数
         //得到根节点的信息熵
   	  double geninfoshang=computeShang(Irisnumber,noIrisnumber);
   	  return geninfoshang;
	}


	//计算增益熵
    private static double computeDevolopShang(double genshang,
			double sepallengthshangBottom, double sepallengthshangTop,
			int size, double fuheerror, double fuheright) {
    	  	
		double number1=fuheerror/(size-1)*sepallengthshangTop;
		double number2=fuheright/(size-1)*sepallengthshangBottom;
		double allnumber=genshang-number1-number2;
		return allnumber;
	}


	//计算特征属性的不同性质的个数(取符合规定属性的下部分)
      private static int getBottomNumberInfo(List<String> cunchu,double number, int suoyin, String str) {
    	  	int index=1;  //主要是第一行的那字母不需要
    		int totalnumber=0;
    		while(index<cunchu.size()){
    			String[] everyzifu=cunchu.get(index).split(",");
    			if((Double.parseDouble(everyzifu[suoyin]))<number&&str.equals(everyzifu[4])){
    				totalnumber++;			
    			}
    				index++;
    		}				
    		return totalnumber;
	}

	//获取每种符合的总个数
      private static int getFuheNumberInfo(List<String> cunchu, double number,int suoyin) {
    	  int index=1;                    //主要是第一行的那字母不需要
    		int totalnumber=0;
    		while(index<cunchu.size()){
    			String[] everyzifu=cunchu.get(index).split(",");
    			if((Double.parseDouble(everyzifu[suoyin]))>=number){
    				totalnumber++;			
    			}
    				index++;
    		}				
    		return totalnumber;
	}

	//计算特征属性的不同性质的个数(取符合规定属性的上部分)
      private static int getTopNumberInfo(List<String> cunchu, double number,int suoyin,String irisnumber2) {
    	 int index=1;  //主要是第一行的那字母不需要
  		int totalnumber=0;
  		while(index<cunchu.size()){
  			String[] everyzifu=cunchu.get(index).split(",");
  			if((Double.parseDouble(everyzifu[suoyin]))>number&& irisnumber2.equals(everyzifu[4])){
  				totalnumber++;			
  			}
  				index++;
  		}				
  		return totalnumber;
	}

	//计算熵值
    private static double computeShang(double value, double base) {
    	double number=0;
    	if(value==0||base==0){   //表示纯度很高
    		number=0;
    	}
    	else{
	    	double number1=(value/(value+base))*Math.log((value/(value+base)))/Math.log(2);	
	    	double number2=(base/(value+base))*Math.log((base/(value+base)))/Math.log(2);	
	    	number=-(number1+number2);
    	}
    			return number;
	}
    
    //将每列数据进行排序
    private static double[] sortAllInfoData(List<String> cunchu, int charpterindex) {
    	double[] paixu=new double[cunchu.size()]; //存储每类的数据
    	int number=0;
    	for(int i=1;i<cunchu.size();i++){
    		String[] sortzifu=cunchu.get(i).split(",");//得到每个字符    		
    		paixu[number]=Double.parseDouble(sortzifu[charpterindex]);
    		number++;
    	}
    	//对数据进行排序
       Arrays.sort(paixu);
       
		return paixu;
	}
    
    

	//获得需要参数的个数
	private static int findNeedNumberInfo(List<String> cunchuallinfo2, String compare,int suoyin) {
		int index=1;  //主要是第一行的那字母不需要
		int totalnumber=0;
		while(index<cunchuallinfo2.size()){
			String[] everyzifu=cunchuallinfo2.get(index).split(",");
			if(compare.equals(everyzifu[suoyin])){
				totalnumber++;			
			}
				index++;
		}				
		return totalnumber;
	}


	//读取txt文件,统计数据
	private static  ArrayList<String> getAllInfo(String filePath) {	
		ArrayList<String> infodata=new ArrayList<>();
		try {
			String encoding = "UTF-8"; //设置编码
			File file = new File(filePath);
			if (file.isFile() && file.exists()) { // 判断文件是否存在
				InputStreamReader read = new InputStreamReader(
						new FileInputStream(file), encoding);     // 考虑到编码格式
				BufferedReader bufferedReader = new BufferedReader(read);
				String lineTxt = null;
				while ((lineTxt = bufferedReader.readLine()) != null) {  //读取的行数内容不是空
					infodata.add(lineTxt);                   //把数据存到数组中
				}
				read.close();
			} else {
				System.out.println("找不到指定的文件");
			}
		} catch (Exception e) {
			System.out.println("读取文件内容出错");
			e.printStackTrace();
		}
		return infodata; //返回所有的数据		
	}
}

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
决策树是常用的机器学习算法之一,通过对数据的分类和特征值计算来完成对未知数据的预测。本文将介绍使用Python实现决策树算法的相关步骤。 首先,需要导入决策树算法工具包,使用以下代码: ```python from sklearn import tree ``` 然后,导入训练数据和测试数据,并进行预处理。为了方便起见,在本文中采用生成随机数的方式来生成样本数据,使用以下代码: ```python from sklearn.datasets import make_classification X, y = make_classification(n_samples=100, n_features=4, n_classes=2, n_informative=2, n_redundant=0, random_state=0, shuffle=False) ``` 接下来,使用生成的样本数据进行模型训练。这里使用scikit-learn中的DecisionTreeClassifier()函数。 ```python clf = tree.DecisionTreeClassifier() clf = clf.fit(X, y) ``` 训练后,调用predict()方法进行对测试数据的预测,使用以下代码: ```python y_pred = clf.predict(X) ``` 最后,评估模型的准确率,使用以下代码: ```python from sklearn.metrics import accuracy_score print(accuracy_score(y, y_pred)) ``` 这就是使用Python实现决策树算法的基本过程。决策树可以根据数据中的不同特征进行分类,是一个简单且常用的分类算法。决策树算法也可用于回归问题,例如预测一个数的大小。与其他机器学习算法相比,决策树具有易于理解和可解释的优点,同时还可以处理非线性的分类问题。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值