决策树的结构:每个内部节点表示在一个属性上的测试,每个分支代表一个属性的输出,每个树叶节点代表类或类分布,树的最顶层是根节点。
实例:由年龄、收入、是否是学生、信用来预测是否买电脑
RID | Age | income | student | Credit_rating | Class:buys_computer |
1 | youth | high | no | fair | no |
2 | youth | high | no | excellent | no |
3 | Middle_aged | high | no | fair | yes |
4 | senior | medium | no | fair | yes |
5 | senior | low | yes | fair | yes |
6 | senior | low | yes | excellent | no |
7 | Middle_aged | low | yes | excellent | yes |
8 | youth | medium | no | fair | no |
9 | youth | low | yes | fair | yes |
10 | youth | medium | yes | fair | yes |
11 | youth | medium | yes | excellent | yes |
12 | Middle_aged | medium | no | excellent | yes |
13 | Middle_aged | high | yes | fair | yes |
14 | senior | medium | no | excellent | no |
熵的概念
我们对一件事情越了解,所需要的信息就越少,反之就越多。所以,变量的不确定性越大,熵
就越大。
获取信息量:Gain(A) = info(D) - Info_A(D),即没有A时候的信息量减去有A之后的信息量
根节点的计算:以上图为例,14种情况中,最终买电脑的为9个,不买电脑为5个,则
以年龄来算信息熵,youth共5个,yes共2个,no共2个,同样计算senior和Middle_aged:
同理计算income、student、Credit_rating的信息熵:
最大所以选age为根节点,然后在每一个分支上再继续计算剩下的信息熵,直到最后。
ID3算法:
1、树以代表训练样本的单个节点开始
2、如果样本都在同一类,则该节点成为树叶,并用该类标号
3、否则,算法使用称为信息增益的基于熵的度量作为启发信息,选择能够最好的将样本分类的属性,该属性成为节点的‘测试’或判定属性。
4、在算法的该版本中,所有属性都是分类的,即离散值,连续属性必须离散化,
5、对测试属性的每个已知的值,创建一个分支,并据此划分样本。
6、算法使用同样的过程,递归地形成每个划分上的样本判定树,一旦一个属性出现在一个样本上,就不必该节点的任何后代考虑他。
7、递归划分步骤仅当下条件成立停止
(1)给定节点的所有样本属于同一类
(2)没有剩余属性可以用来进一步划分样本。在此情况下使用多数表决,这涉及将给定的节点转化成树叶,并用样本中的多数所在的类标记他,替换他,可以存放节点样本的类分布。
(3)分支test_attribute = a,没有样本,此情况下,以sample中的多数类创建树叶
关于剪枝:
(1)先剪枝
(2)后剪枝
决策树的优点:直观便于理解,小规模数据集有效
缺点:处理连续变量不好,类别较多时,错误增加的比较快,可规模型一般
程序:
from sklearn.feature_extraction import DictVectorizer
import csv
from sklearn import preprocessing
from sklearn import tree
# from sklearn.externals.six import StringIO
allElectronicsDate = open(r'.\Class_buys_computer.csv', 'rt')
reader = csv.reader(allElectronicsDate) # CSV模块自带的reader方法,可按行读取内容
# print('reader:'+ str(reader))
headers = next(reader) #读出的是表头部分:['RID', 'Age', 'income', 'student', 'Credit_rating', 'Class:buys_computer']
# print(headers)
featureList = [] #装特征分类['Age', 'income', 'student', 'Credit_rating']
labelList = [] #装类别[ 'Class:buys_computer']
for row in reader: #执行每一行的数据
# print(row)
labelList.append(row[len(row) - 1]) #取每一行最后一个值,即class label将其装进labelList
# print(labelList)
rowDict = {} #创建字典
for i in range(1, len(row) - 1):
# print(row[i])
# 字典的key就是csv中的age对应的属性['Age', 'income', 'student', 'Credit_rating']
#值就是具体的属性
rowDict[headers[i]] = row[i]
# print('rowDict',rowDict)
featureList.append(rowDict)
print(featureList)
print(labelList)
vec = DictVectorizer() # python自带模块
dummyX = vec.fit_transform(featureList).toarray()
# 调用方法fit_transform将字典类型的[{'a':'b'},{'c':'d'}]数据中的'b','d'数据转换成0,1的矩阵形式
# print("dummyX:" + str(dummyX))
# print(vec.get_feature_names()) # 调用此方法得到'b','d'对应的特征名
print("labellist:" + str(labelList))
lb = preprocessing.LabelBinarizer() # python内部模块
dummyY = lb.fit_transform(labelList) # 调用fit_transform方法将标签列表中的数据转成0,1格式
print("dummyY:" + str(dummyY))
clf = tree.DecisionTreeClassifier(criterion='entropy')
# tree模块,创建clf分类器,entropy表示度量标准信息熵
clf = clf.fit(dummyX, dummyY)
# 用训练数据dummyX,dummyY拟合分类器模型
print("clf:" + str(clf))
with open("allElectronicInformationGainOri.dot", 'w') as f:
f = tree.export_graphviz(clf, feature_names=vec.get_feature_names(), out_file=f)
# 通过export_graphviz模块导出dot文件到1.dot文件中,后通过cmd命令dot -Tpdf 1.dot -o 1.pdf
# 将dot文件转化成pdf视图
# ↓↓↓造一组新数据,来预测分类
oneRowX = dummyX[0, :] # 取X矩阵数组里面的第一行
print("oneRowX:" + str(oneRowX))
newRowX = oneRowX # 赋给新标签
newRowX[0] = 1 #原始矩阵数组里面的第一行
newRowX[2] = 0 #将第三个属性student改为0
print("newRowX:" + str(newRowX))
# newRowX:[ 1. 0. 0. 0. 1. 1. 0. 0. 1. 0.]
newRowX = newRowX.reshape(1, -1)
# 将列表转化为矩阵,共predict调用
print("newRowX:" + str(newRowX))
# newRowX:[[ 1. 0. 0. 0. 1. 1. 0. 0. 1. 0.]]
predictedY = clf.predict(newRowX)
# 用之前创建好的分类器clf(classifier),newRowX必须是矩阵类型
print("predictedY:" + str(predictedY))
运行完之后会在当前目录生成一个allElectronicInformationGainOri.dot文件,此处需要安装:
Graphviz工具
具体安装过程见:https://blog.csdn.net/qq_42006303/article/details/99196645
官网下载较慢,这里提供链接可直接下载:
然后用如下命令将allElectronicInformationGainOri.dot文件转化为Source.gv.pdf文件
import graphviz
with open("allElectronicInformationGainOri.dot") as f:
dot_graph = f.read()
dot=graphviz.Source(dot_graph)
dot.view()
结果如下:
相关代码和文件:https://download.csdn.net/download/qq_42006303/11522659