在之前说了用线性回归的方法来对训练数据进行训练,然后通过得到的方程式来对测试数据进行了测试,这里就介绍下,自己对于同样的问题而进行决策树的划分构造树结构。
在这里就不重复说训练数据的格式了,可以看看我之前写的线性回归的那一篇文章。
决策树的步骤:
一、 实验要求
对数据使用决策树的方法对鸢尾花进行分类,进行实验比较、精度比较并写成实验报告。
二、 实验原理
决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。
决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。
决策树是一种十分常用的分类方法。他是一种监管学习,所谓监管学习就是给定一堆样本,每个样本都有一组属性和一个类别,这些类别是事先确定的,那么通过学习得到一个分类器,这个分类器能够对新出现的对象给出正确的分类。这样的机器学习就被称之为监督学习。
决策数有两大优点:1)决策树模型可以读性好,具有描述性,有助于人工分析;2)效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。
三丶 实验思路
由之前(上一篇文章)的训练数据可以看出,该数据类型是属于数字型(还有一种叫做名称型),所以对于这样的数据的划分,应该采取用“>=”,“>”,“<”或“<=”作为分割条件,这样能对树的构建更新明显和优化时间复杂度。
决策树模型的构建:
通过如下的表达式来进行一步步的树模型的构建.
Gini不纯度
熵(Entropy)
信息增益(Information Gain)
四丶算法实现(采用Java语言,基于Eclipse平台)
1:读取测试集数据,并对测试集进行每列数据分割的处理,主要是方便每个特征的划分参数处理。
2:根据熵的计算公式,从而获取到最后一列特征(属于某种类型花)的信息熵
3:根据公式,计算每个特征的信息熵,从而确定出最大信息增益熵,以便得到根节点属性。
其中,因为这是属于数值型的数据,所以自己在对数据进行处理的时候,是首先将每列的数据从小到大进行排序,然后通过迭代循环,依次找到两个连续点的中点值,来作为划分参数,来获得特征的信息熵,并依次对每次的参数得到的信息熵进行比较,从而得到该特征的最大的信息熵。
4:依次对测试集显示的4个属性按照(3)中的方法进行处理
5:在得到每个特征的信息熵之后,通过公式计算出相应的信息增益熵,从而来得到第一层的根节点属性。
6:通过不断的上述步骤,依次得到每层的叶节点和根节点的划分。
7:将上述步骤中得到的每层相对应的特征名,划分参数,所属类别,进行排序处理,方便之后打印出每层树的结构。(就是对参数的排序)
8:打印出决策树的结构
9:对测试集通过决策树来进行预测,得到准确度。
五丶 实验结果
通过上述的步骤,从而得到以下的输出结果,其中包括树的结构还有自己选取部分测试数据进行测试的结果显示。
决策树的结构:
部分数据的结果分析(自己在训练集中独立分割的部分数据):
代码也贴出来吧(有比较多冗余的地方,没优化,需要的就看看下)
package machinetest;
/*
* 进行决策树的构建
*
*/
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
//计算信息熵
public class ComputeInfoValues {
static List<String> cunchuallinfo=new ArrayList<>(); //用来保存所有的数据
static List<String> inputetextdata=new ArrayList<>();//测试数据
static double[] everyvalue=new double[4]; //保存要进行输出的决策树的变量
static double[] maxvalue=new double[4]; //存储划分的参数
static String[] tezhengname=new String[4]; //存储对应的特征名字
static double Max_value=0;
public static void main(String[] args){
//读取txt文件,统计数据
String url="H:\\Iris.txt"; //数据源txt路径
cunchuallinfo=getAllInfo(url); //得到输入数据源数据
//得到根的信息熵
double geninfoshang=showGenShang();
//得到Sepal.Length的信息熵
double firstsepallength=showDiffierentShang(geninfoshang,0);
double firstvalues=Max_value; //得到区分值(参数)
everyvalue[0]=firstsepallength;
maxvalue[0]=firstvalues;
tezhengname[0]="Sepal.Length";
// System.out.println("第一个特征的增益熵:"+firstsepallength);
// System.out.println(Max_value);
//得到Sepal.Width的信息熵
double secondSepalWidth=showDiffierentShang(geninfoshang,1);
double secondvalues=Max_value;
everyvalue[1]=secondSepalWidth;
maxvalue[1]=secondvalues;
tezhengname[1]="Sepal.Width";
// System.out.println("第二个特征的增益熵:"+secondSepalWidth);
// System.out.println(Max_value);
//得到Sepal.Width的信息熵
double threadPetalLength=showDiffierentShang(geninfoshang,2);
double threevalues=Max_value;
everyvalue[2]=threadPetalLength;
maxvalue[2]=threevalues;
tezhengname[2]="Petal.Length";
// System.out.println("第三个特征的增益熵:"+threadPetalLength);
// System.out.println(Max_value);
// 得到PetalWidth的信息熵
double fourPetalWidth=showDiffierentShang(geninfoshang,3);
double fourvalues=Max_value;
everyvalue[3]=fourPetalWidth;
maxvalue[3]=fourvalues;
tezhengname[3]="Petal.Width";
// System.out.println("第四个特征的增益熵:"+fourPetalWidth);
// System.out.println(Max_value);
//找到对应要输出的内容,进行排序,方便打印
dealShunXuValue(everyvalue,maxvalue,tezhengname);
//打印决策树
outputTreeConstruction(everyvalue,maxvalue,tezhengname); //输出决策树的结构
judgeDataResult(firstvalues,secondvalues,threevalues,fourvalues); //进行决策树测试
//
}
//将要进行输出的决策树的变量进行整合,也就是按信息增益熵从小到大进行排序
private static void dealShunXuValue(double[] everyvalue,double[] maxvalue, String[] tezhengname) {
for(int i=0;i<4;i++){ //冒泡排序(大的放在前面)
for(int m=3;m>i;m--){
if(everyvalue[m]>everyvalue[m-1]){
//进行交换(信息增益熵)
double temp=everyvalue[m];
everyvalue[m]=everyvalue[m-1];
everyvalue[m-1]=temp;
//交换对应的特征名字
String temp2=tezhengname[m];
tezhengname[m]=tezhengname[m-1];
tezhengname[m-1]=temp2;
//交换对应的划分参数
double temp3=maxvalue[m];
maxvalue[m]=maxvalue[m-1];
maxvalue[m-1]=temp3;
}
}
}
}
//输出决策树的结构(之前已经按照排序好的顺序进行输出)
private static void outputTreeConstruction(double[] everyvalue,double[] maxvalue, String[] tezhengname) {
//输出决策树的结构
System.out.println("第一层: "+ tezhengname[0]);
System.out.println(" / \\");
System.out.println(" (<"+maxvalue[0]+") / \\"+" (>="+maxvalue[0]+")");
System.out.println(" / \\");
System.out.println("第二层: "+tezhengname[1]+" (Iris-versicolor)" );
System.out.println(" /");
System.out.println(" (<"+maxvalue[1]+") / \\"+" (>="+maxvalue[1]+")");
System.out.println(" / \\");
System.out.println("第三层: "+ tezhengname[2]+" (Iris-versicolor) ");
System.out.println(" / ");
System.out.println(" (<"+maxvalue[2]+") / \\"+" (>="+maxvalue[2]+")");
System.out.println(" / \\");
System.out.println("第四层: "+tezhengname[3]+" (Iris-versicolor) ");
System.out.println(" /");
System.out.println(" (<"+maxvalue[3]+") / \\"+" (>="+maxvalue[3]+")");
System.out.println("第五层: / \\");
System.out.println("(Iris-versicolor) (Iris-setosa) ");
}
//对测试集数据进行测试
private static void judgeDataResult(double firstvalues,
double secondvalues, double threevalues, double fourvalues) {
进行测试集的比较
String urlinput="H:\\text.txt"; //测试集txt路径
inputetextdata=getAllInfo(urlinput); //得到输入数据源数据
int panduansuoyin=0;
String result="";
String[] ceshidatafinallylie=new String[inputetextdata.size()]; //存储测试数据中的最后一列,方便最后查看结果
String[] panduanshuju=new String[inputetextdata.size()]; //存储通过决策树判断的结果(只需要存储最后一列)
//进行测试
for(int i=0;i<inputetextdata.size();i++){
String[] text=inputetextdata.get(i).split(",");
ceshidatafinallylie[panduansuoyin]=text[4]; //将测试集的最后一列进行存储
if(Double.parseDouble(text[2])>threevalues){ //第三个特征大于MAX的情况
if(Double.parseDouble(text[3])>=fourvalues){
panduanshuju[panduansuoyin]="Iris-versicolor";
panduansuoyin++;
}
else{ //第三个特征小于MAX的情况
if(Double.parseDouble(text[0])>firstvalues){ //第一个特征大于MAX的情况 (因为第一个比第二个特征增益熵大)
panduanshuju[panduansuoyin]="Iris-versicolor";
panduansuoyin++;
}
else{ //第一个特征小于MAX的情况
if(Double.parseDouble(text[1])>secondvalues){
panduanshuju[panduansuoyin]="Iris-setosa";
panduansuoyin++;
}
else{
panduanshuju[panduansuoyin]="Iris-versicolor";
panduansuoyin++;
}
}
}
}
else{ //第三个特征小于Max的情况
if(Double.parseDouble(text[3])>=fourvalues){ //第四个特征大于MAX的情况
panduanshuju[panduansuoyin]="Iris-versicolor";
panduansuoyin++;
}
else{
panduanshuju[panduansuoyin]="Iris-setosa"; //第四个特征小于MAX的情况
panduansuoyin++;
}
}
}
//进行比较结果
System.out.println("预测值\t\t\t\t实际值:\t\t\t\t结果");
String panduanresult="";
double totalrightnumber=0;
double totalerrornumber=0;
for(int i=0;i<panduanshuju.length;i++){
if(panduanshuju[i].equals(ceshidatafinallylie[i])){
panduanresult="right";
totalrightnumber++;
}
else{
panduanresult="error";
totalerrornumber++;
}
System.out.println(panduanshuju[i]+"\t\t\t"+ceshidatafinallylie[i]+"\t\t\t"+panduanresult);
}
//输出精确度
//输出精准度(结果保留两位小数)
double d = (totalrightnumber)/(totalerrornumber+totalrightnumber)*100;
String resultdata = String.format("%.2f", d);
System.out.println("预测的精确度是:"+resultdata+"%");
}
//求每个特征的信息增益熵
private static double showDiffierentShang(double geninfoshang,int charpterindex) {
//将训练数据进行排序(方便找到参数点,得到最大增益熵)
double Max_zengyishang=0;
double[] textdata=sortAllInfoData(cunchuallinfo,charpterindex);
for(int i=1;i<textdata.length;i++){ //找到合适的参数,并得到最大的增益熵的数
double testnumber=(textdata[i]+textdata[i-1])/2;
int fuheerror=getFuheNumberInfo(cunchuallinfo,testnumber,charpterindex);
int sepalerrortop=getTopNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-versicolor"); //大于5.0的个数
int sepalrighttop=getTopNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-setosa");
int fuheright=cunchuallinfo.size()-1-fuheerror;
int sepalrightbottom=getBottomNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-setosa");
int sepalerrorbottom=getBottomNumberInfo(cunchuallinfo,testnumber,charpterindex,"Iris-versicolor");
double SepallengthshangTop=computeShang(sepalerrortop, sepalrighttop);
double SepallengthshangBottom=computeShang(sepalerrorbottom,sepalrightbottom);
double computershangsecond=computeDevolopShang(geninfoshang,SepallengthshangBottom,SepallengthshangTop,cunchuallinfo.size(),fuheerror,fuheright);
if(Max_zengyishang<=computershangsecond){ //判断是否是最大的增益熵
Max_zengyishang=computershangsecond;
Max_value=textdata[i]; //获取到单列中让增益熵最大的值
}
}
return Max_zengyishang;
}
// //显示第一个特征的信息增益熵
// private static double showSepalLengthShang(double geninfoshang) {
//
// int fuheerror=getFuheNumberInfo(cunchuallinfo,5.0,0); //大于5.0的个数
// int sepalerrortop=getTopNumberInfo(cunchuallinfo,5.0,0,"Iris-versicolor"); //大于5.0的个数
// int sepalrighttop=getTopNumberInfo(cunchuallinfo,5.0,0,"Iris-setosa");
//
// int fuheright=cunchuallinfo.size()-1-fuheerror; //小于5.0的个数
// int sepalrightbottom=getBottomNumberInfo(cunchuallinfo,5.0,0,"Iris-setosa");
// int sepalerrorbottom=getBottomNumberInfo(cunchuallinfo,5.0,0,"Iris-versicolor");
//
// double SepallengthshangTop=computeShang(sepalerrortop, sepalrighttop);
// double SepallengthshangBottom=computeShang(sepalerrorbottom, sepalrightbottom);
//
double zengyilv=computeIvValue(fuheerror,fuheright); //计算增益率Iv
System.out.println(zengyilv);
//
// double computershangfirst=computeDevolopShang(geninfoshang,SepallengthshangBottom,SepallengthshangTop,cunchuallinfo.size(),fuheerror,fuheright);
// return computershangfirst;
// }
//得到IV值
private static double computeIvValue(double value, double base) {
double number1=((value/17)*Math.log((value/17))/Math.log(2));
double number2=(value/17)*Math.log((value/17))/Math.log(2);
double number=-(number1+number2);
return number;
}
//得到根的信息熵
private static double showGenShang() {
int Irisnumber=findNeedNumberInfo(cunchuallinfo,"Iris-setosa",4); //得到Iris-setosa的个数
int noIrisnumber=findNeedNumberInfo(cunchuallinfo,"Iris-versicolor",4); //得打所有反例的个数
//得到根节点的信息熵
double geninfoshang=computeShang(Irisnumber,noIrisnumber);
return geninfoshang;
}
//计算增益熵
private static double computeDevolopShang(double genshang,
double sepallengthshangBottom, double sepallengthshangTop,
int size, double fuheerror, double fuheright) {
double number1=fuheerror/(size-1)*sepallengthshangTop;
double number2=fuheright/(size-1)*sepallengthshangBottom;
double allnumber=genshang-number1-number2;
return allnumber;
}
//计算特征属性的不同性质的个数(取符合规定属性的下部分)
private static int getBottomNumberInfo(List<String> cunchu,double number, int suoyin, String str) {
int index=1; //主要是第一行的那字母不需要
int totalnumber=0;
while(index<cunchu.size()){
String[] everyzifu=cunchu.get(index).split(",");
if((Double.parseDouble(everyzifu[suoyin]))<number&&str.equals(everyzifu[4])){
totalnumber++;
}
index++;
}
return totalnumber;
}
//获取每种符合的总个数
private static int getFuheNumberInfo(List<String> cunchu, double number,int suoyin) {
int index=1; //主要是第一行的那字母不需要
int totalnumber=0;
while(index<cunchu.size()){
String[] everyzifu=cunchu.get(index).split(",");
if((Double.parseDouble(everyzifu[suoyin]))>=number){
totalnumber++;
}
index++;
}
return totalnumber;
}
//计算特征属性的不同性质的个数(取符合规定属性的上部分)
private static int getTopNumberInfo(List<String> cunchu, double number,int suoyin,String irisnumber2) {
int index=1; //主要是第一行的那字母不需要
int totalnumber=0;
while(index<cunchu.size()){
String[] everyzifu=cunchu.get(index).split(",");
if((Double.parseDouble(everyzifu[suoyin]))>number&& irisnumber2.equals(everyzifu[4])){
totalnumber++;
}
index++;
}
return totalnumber;
}
//计算熵值
private static double computeShang(double value, double base) {
double number=0;
if(value==0||base==0){ //表示纯度很高
number=0;
}
else{
double number1=(value/(value+base))*Math.log((value/(value+base)))/Math.log(2);
double number2=(base/(value+base))*Math.log((base/(value+base)))/Math.log(2);
number=-(number1+number2);
}
return number;
}
//将每列数据进行排序
private static double[] sortAllInfoData(List<String> cunchu, int charpterindex) {
double[] paixu=new double[cunchu.size()]; //存储每类的数据
int number=0;
for(int i=1;i<cunchu.size();i++){
String[] sortzifu=cunchu.get(i).split(",");//得到每个字符
paixu[number]=Double.parseDouble(sortzifu[charpterindex]);
number++;
}
//对数据进行排序
Arrays.sort(paixu);
return paixu;
}
//获得需要参数的个数
private static int findNeedNumberInfo(List<String> cunchuallinfo2, String compare,int suoyin) {
int index=1; //主要是第一行的那字母不需要
int totalnumber=0;
while(index<cunchuallinfo2.size()){
String[] everyzifu=cunchuallinfo2.get(index).split(",");
if(compare.equals(everyzifu[suoyin])){
totalnumber++;
}
index++;
}
return totalnumber;
}
//读取txt文件,统计数据
private static ArrayList<String> getAllInfo(String filePath) {
ArrayList<String> infodata=new ArrayList<>();
try {
String encoding = "UTF-8"; //设置编码
File file = new File(filePath);
if (file.isFile() && file.exists()) { // 判断文件是否存在
InputStreamReader read = new InputStreamReader(
new FileInputStream(file), encoding); // 考虑到编码格式
BufferedReader bufferedReader = new BufferedReader(read);
String lineTxt = null;
while ((lineTxt = bufferedReader.readLine()) != null) { //读取的行数内容不是空
infodata.add(lineTxt); //把数据存到数组中
}
read.close();
} else {
System.out.println("找不到指定的文件");
}
} catch (Exception e) {
System.out.println("读取文件内容出错");
e.printStackTrace();
}
return infodata; //返回所有的数据
}
}