对iris数据集进行分类,数据长这样:
第 0、1、2、3 列表示特征,第 4 列表示花所属的类别(共 3 类)。
本文只取第 0、1 两列作为分类的依据。
分类结果如下:
代码如下:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import time
class Tree:
    """Binary decision tree that picks splits by minimum weighted Gini impurity.

    Each node stores the rows that reached it in ``data``. Internal nodes
    additionally hold the chosen ``feature`` (a column label), the threshold
    ``seperator`` (attribute spelling kept for backward compatibility) and
    the two children ``left`` (rows with value < threshold) and ``right``
    (rows with value >= threshold). Leaves have ``left is None`` and
    ``right is None`` and predict the most frequent class among their rows.

    The tree is built recursively inside ``__init__``.
    """

    # Column labels that are examined as candidate split features.
    FEATURES = (0, 1)
    # Positional index of the column that holds the integer class code.
    LABEL_POSITION = 4
    # Minimum length of the per-class count vector.
    N_CLASSES = 3

    def __init__(self, data):
        """Build the (sub)tree for ``data``.

        Args:
            data: pandas DataFrame whose columns named in ``FEATURES`` are
                numeric and whose column at position ``LABEL_POSITION``
                holds non-negative integer class codes.
        """
        self.left = None
        self.right = None
        self.data = data
        self.feature = 0
        self.seperator = 0
        self._construct()

    def _construct(self):
        """Choose the best split for this node and recurse into the children."""
        if len(self.data) <= 1:
            return

        labels = self._labels(self.data)
        # A pure node cannot be improved by splitting; every descendant
        # would predict the same class, so stop here.
        if np.all(labels == labels[0]):
            return

        # Compute the best Gini index for each candidate feature and keep
        # the minimum; ties go to the feature that comes first.
        best_gini = float("inf")
        for feature in self.FEATURES:
            gini, threshold = self._best_split(self.data, feature)
            if gini < best_gini:
                best_gini = gini
                self.feature = feature
                self.seperator = threshold

        # No feature offers a threshold that separates the rows
        # (all candidate feature values are identical): stay a leaf.
        if best_gini == float("inf"):
            return

        column = self.data[self.feature]
        lower = self.data[column < self.seperator]
        upper = self.data[column >= self.seperator]
        if len(lower) > 0 and len(upper) > 0:
            self.left = type(self)(lower)
            self.right = type(self)(upper)

    def _best_split(self, data, feature):
        """Find the threshold on ``feature`` with the lowest weighted Gini.

        Candidate thresholds are the midpoints between consecutive distinct
        values of the feature. Splits that leave one side empty are skipped.

        Returns:
            ``(gini, threshold)`` for the best split, or ``(inf, 0)`` when
            the feature has no threshold that separates the rows.
        """
        values = data[feature].to_numpy()
        labels = self._labels(data)
        total = len(values)

        distinct = np.unique(values)  # sorted, duplicates removed
        thresholds = distinct[:-1] + np.diff(distinct) / 2

        best_gini = float("inf")
        best_threshold = 0
        for threshold in thresholds:
            lower = labels[values < threshold]
            upper = labels[values >= threshold]
            # A midpoint can collapse onto a data value through rounding;
            # such a split separates nothing and must not be scored.
            if len(lower) == 0 or len(upper) == 0:
                continue
            gini = (
                len(lower) * self._gini(lower) + len(upper) * self._gini(upper)
            ) / total
            if gini < best_gini:
                best_gini = gini
                best_threshold = threshold
        return best_gini, best_threshold

    def _gini(self, labels):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of a label array."""
        fractions = self._class_fractions(labels)
        return 1 - np.sum(fractions ** 2)

    def _class_fractions(self, labels):
        """Return the fraction of rows per class code.

        The result has at least ``N_CLASSES`` entries; an empty input yields
        all zeros instead of dividing by zero.
        """
        if len(labels) == 0:
            return np.zeros(self.N_CLASSES)
        counts = np.bincount(labels, minlength=self.N_CLASSES)
        return counts / counts.sum()

    def _labels(self, data):
        """Return the class-code column of ``data`` as a NumPy array."""
        return data.iloc[:, self.LABEL_POSITION].to_numpy()

    def classify(self, sample):
        """Predict the class code for one sample.

        Args:
            sample: anything indexable by the feature labels, e.g. a list
                ``[f0, f1]`` or a DataFrame row.

        Returns:
            The most frequent class code (int) of the leaf the sample lands
            in; the lowest code wins ties. If the feature value compares
            neither below nor at/above a threshold (e.g. NaN), the string
            ``"unkonw"`` is returned, as in the original implementation.

        Raises:
            ValueError: if the tree was built from empty data.
        """
        if self.left is None and self.right is None:
            if len(self.data) == 0:
                raise ValueError("cannot classify: tree was built from empty data")
            fractions = self._class_fractions(self._labels(self.data))
            return int(np.argmax(fractions))

        value = sample[self.feature]
        if value < self.seperator:
            return self.left.classify(sample)
        if value >= self.seperator:
            return self.right.classify(sample)
        # Literal (including its spelling) kept so existing callers that
        # compare against it keep working.
        return "unkonw"
# Load the data set; column 4 holds the class name and is re-encoded
# as integer category codes.
data = pd.read_csv("iris.csv", header=None)
data[4] = pd.Categorical(data[4]).codes

# Build the decision tree on the full data set.
tree = Tree(data)

# Training accuracy: count the rows whose predicted class equals the label.
correct = 0
for row_pos, actual in enumerate(data[4]):
    predicted = tree.classify(data.iloc[row_pos])
    if actual == predicted:
        correct += 1
print(f"正确率:{correct}/{len(data)}")

# Plot the decision regions over a 100 x 100 grid of (feature 0, feature 1)
# values, then overlay the samples coloured by their true class.
colors = mpl.colors.ListedColormap(['#FFC4B2', '#B1FFC3', '#B5FFFE'])
grid_f0 = np.linspace(4, 8, 100)
grid_f1 = np.linspace(1.5, 4.5, 100)
x, y = np.meshgrid(grid_f0, grid_f1)
z = np.zeros(x.shape)
for cell in np.ndindex(*x.shape):
    z[cell] = tree.classify([x[cell], y[cell]])
plt.pcolormesh(x, y, z, cmap=colors)
plt.scatter(data[0], data[1], c=data[4])
plt.show()