Training a decision tree on the iris dataset
Python implementation; the code is somewhat messy.
import random
import math
import copy
import sys
class decisionTree(object):
def __init__(self, file, label):
self.file = file
self.label = label
self.dataset = self.initDataset()
self.readfile()
def initDataset(self):
dataset={}
for i in range(len(self.label)):
dataset[self.label[i]] = []
return dataset
    def readfile(self):
        with open(self.file, 'r') as myfile:
            for line in myfile:
                line = line.strip()
                data = line.split(',')
                if data[-1] != '':
                    self.dataset[data[-1]].append(data[:-1])
        # convert the feature strings to floats
        for i in range(len(self.label)):
            for j in range(len(self.dataset[self.label[i]])):
                for m in range(len(self.dataset[self.label[i]][j])):
                    self.dataset[self.label[i]][j][m] = float(self.dataset[self.label[i]][j][m])
    # proportion of samples per label, so that sampling keeps the classes balanced
def k_cross(self,k):
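        # stratified folds: every fold draws about len(class_samples) / k rows
        # from each class, so the class ratios stay close to the full dataset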
num = []
for i in range(len(self.label)):
num.append(len(self.dataset[self.label[i]]))
        for i in range(len(num)):
            # per-fold quota for this class, rounded to the nearest integer
            num[i] = num[i] / k
            if num[i] - int(num[i]) > 0.5:
                num[i] = int(num[i]) + 1
            else:
                num[i] = int(num[i])
        # draw the k stratified folds
data = []
for i in range(k):
data.append([])
        all_list = []
        for i in range(len(self.label)):
            # candidate sample indices for class i
            all_list.append(list(range(len(self.dataset[self.label[i]]))))
        for fold in range(k):
            for i in range(len(self.label)):
                chosen = random.sample(all_list[i], num[i])
                # remove the chosen indices from the candidate pool
                all_list[i] = list(set(all_list[i]) ^ set(chosen))
                for j in range(len(chosen)):
                    # copy the row so the stored dataset is not mutated
                    pop_data = self.dataset[self.label[i]][chosen[j]] + [self.label[i]]
                    data[fold].append(pop_data)
        return data
    # returns all k folds as [train, test] pairs: [[[train],[test]], [[train],[test]], ...]
def k_data(self,k):
dataset = self.k_cross(k)
        data = []
        for i in range(len(dataset)):
            data.append([[], []])  # k pairs of [training set, test set]
        for i in range(len(dataset)):
            data[i][1] = dataset[i]
            for j in range(len(dataset)):
                if j != i:
                    for m in range(len(dataset[j])):
                        data[i][0].append(dataset[j][m])
        return data
    # dataset here should be a caller-supplied training set, a list of rows like [[...], [...], ...]; returns the tree structure described below
def train_tree(self,dataset):
        # compute information entropy
def ent(dataset):
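            # information entropy: Ent(D) = -sum_k p_k * log2(p_k),
            # where p_k is the fraction of samples of class k in D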
num=[]
for i in range(len(self.label)):
num.append(0)
for i in range(len(dataset)):
for j in range(len(self.label)):
if dataset[i][-1] == self.label[j] :
num[j] += 1
all = 0
for i in range(len(num)):
all += num[i]
ent_data = 0
for i in range(len(num)):
if num[i] != 0:
ent_data -= num[i]/all * math.log2(num[i]/all)
return ent_data
        def Gain(dataset,root,key): # key is the feature index
            def Gain_sub(dataset,root,key,a): # a is the index of the candidate split point
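                # gain for the split point t = root[key][a]:
                # Gain(D, t) = Ent(D) - |D1|/|D| * Ent(D1) - |D2|/|D| * Ent(D2)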
ent_data = ent(dataset)
sub_data_1 = []
sub_data_2 = []
for i in range(len(dataset)):
if dataset[i][key] < root[key][a]:
sub_data_1.append(dataset[i])
else :
sub_data_2.append(dataset[i])
                gain_data = ent_data - (len(sub_data_1)/len(dataset)) * ent(sub_data_1) - (len(sub_data_2)/len(dataset)) * ent(sub_data_2)
return gain_data
gain=[]
for i in range(len(root[key])):
gain.append(Gain_sub(dataset,root,key,i))
return max(gain),gain.index(max(gain))
def next_opr(dataset):
            # compute the information gain for every feature of this dataset; the result is a nested dict: outer key = feature index, inner key = split point, value = information gain
            # summarise the value set of every feature
feature = {}
for i in range(len(dataset[0]) - 1):
feature[i] = []
for i in range(len(dataset)):
for j in range(len(dataset[i]) - 1):
if dataset[i][j] not in feature[j]:
feature[j].append(dataset[i][j])
for i in range(len(feature.keys())):
feature[i] = sorted(feature[i])
            # candidate split points for the continuous values (midpoints of adjacent sorted values)
root = {}
for i in range(len(feature.keys())):
root[i] = []
for i in range(len(feature.keys())):
for j in range(len(feature[i])):
if j != len(feature[i]) - 1:
root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
                if root[i]:
                    # Gain already scans every candidate split point, so one call per feature suffices
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain
        # the tree is stored as a list [root, left, right]: element 0 holds the split, element 1 the left subtree, element 2 the right; a child that is not yet a label stores its dataset
        def key_root(my_gain):
            # parse the stored gains; return the feature index (key) and the split point (root) with the largest gain
            key_1 = list(my_gain.keys())
            best = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > best:
                        best = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root
tree = []
my_gain = next_opr(dataset)
key,root = key_root(my_gain)
tree.append([key,root])
tree.append([])
tree.append([])
        # split the remaining dataset using key and root
def sub_root(dataset,key,root):
sub_left = []
sub_right = []
for i in range(len(dataset)):
if dataset[i][key] < root:
sub_left.append(dataset[i])
else:
sub_right.append(dataset[i])
return sub_left,sub_right
sub_left, sub_right = sub_root(dataset,key,root)
tree[1] = sub_left
tree[2] = sub_right
        # check whether all samples in the left/right subtree share one label
def test(dataset):
num=[]
for i in range(len(self.label)):
num.append(0)
for i in range(len(dataset)):
for j in range(len(self.label)):
if dataset[i][-1] == self.label[j]:
num[j] +=1
if max(num) == len(dataset):
return self.label[num.index(max(num))]
else:
return dataset
def next(tree):
for i in range(len(tree)):
if i != 0:
tree[i] = test(tree[i])
for i in range(len(tree)):
if i != 0:
if tree[i] not in self.label:
dataset = tree[i]
tree[i] = []
gains = next_opr(dataset)
key,root = key_root(gains)
tree[i].append([key,root])
tree[i].append([])
tree[i].append([])
left,right = sub_root(dataset,key,root)
tree[i][1] = test(left)
tree[i][2] = test(right)
next(tree[i])
next(tree)
return tree
    # train a tree with pre-pruning
def pre_pruning(self,train_dataset,test_dataset):
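        # pre-pruning: grow the tree one split at a time and keep a split only
        # while it improves accuracy on test_dataset; otherwise the node stays a leaf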
def next_opr(dataset):
            # compute the information gain for every feature of this dataset; the result is a nested dict: outer key = feature index, inner key = split point, value = information gain
            def Gain(dataset, root, key):  # key is the feature index
                # compute information entropy
def ent(dataset):
num = []
for i in range(len(self.label)):
num.append(0)
for i in range(len(dataset)):
for j in range(len(self.label)):
if dataset[i][-1] == self.label[j]:
num[j] += 1
all = 0
for i in range(len(num)):
all += num[i]
ent_data = 0
for i in range(len(num)):
if num[i] != 0:
ent_data -= num[i] / all * math.log2(num[i] / all)
return ent_data
                def Gain_sub(dataset, root, key, a):  # a is the index of the candidate split point
ent_data = ent(dataset)
sub_data_1 = []
sub_data_2 = []
for i in range(len(dataset)):
if dataset[i][key] < root[key][a]:
sub_data_1.append(dataset[i])
else:
sub_data_2.append(dataset[i])
                    gain_data = (ent_data
                                 - (len(sub_data_1) / len(dataset)) * ent(sub_data_1)
                                 - (len(sub_data_2) / len(dataset)) * ent(sub_data_2))
return gain_data
gain = []
for i in range(len(root[key])):
gain.append(Gain_sub(dataset, root, key, i))
return max(gain), gain.index(max(gain))
            # summarise the value set of every feature
feature = {}
for i in range(len(dataset[0]) - 1):
feature[i] = []
for i in range(len(dataset)):
for j in range(len(dataset[i]) - 1):
if dataset[i][j] not in feature[j]:
feature[j].append(dataset[i][j])
for i in range(len(feature.keys())):
feature[i] = sorted(feature[i])
            # candidate split points for the continuous values (midpoints of adjacent sorted values)
root = {}
for i in range(len(feature.keys())):
root[i] = []
for i in range(len(feature.keys())):
for j in range(len(feature[i])):
if j != len(feature[i]) - 1:
root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
                if root[i]:
                    # Gain already scans every candidate split point, so one call per feature suffices
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain
        # the tree is stored as a list [root, left, right]: element 0 holds the split, element 1 the left subtree, element 2 the right; a child that is not yet a label stores its dataset
        def key_root(my_gain):
            # parse the stored gains; return the feature index (key) and the split point (root) with the largest gain
            key_1 = list(my_gain.keys())
            best = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > best:
                        best = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root
        # split the remaining dataset using key and root
def sub_root(dataset, key, root):
sub_left = []
sub_right = []
for i in range(len(dataset)):
if dataset[i][key] < root:
sub_left.append(dataset[i])
else:
sub_right.append(dataset[i])
return sub_left, sub_right
        def max_lable(dataset):
            # return the label with the largest sample count
            if type(dataset) == type([]):
                num = []
                for i in range(len(self.label)):
                    num.append(0)
                for i in range(len(dataset)):
                    for j in range(len(self.label)):
                        if dataset[i][-1] == self.label[j]:
                            num[j] += 1
                return self.label[num.index(max(num))]
            elif type(dataset) == type('abc'):
                return dataset
def test_tree(tree,dataset):
            # measure this tree's accuracy on dataset
def test_label(train_tree,test_data):
                # check whether the tree classifies this sample correctly
label = None
if test_data[train_tree[0][0]] < train_tree[0][1]:
if train_tree[1] not in self.label:
train_tree = train_tree[1]
label = test_label(train_tree, test_data)
else:
label = train_tree[1]
else:
if train_tree[2] not in self.label:
train_tree = train_tree[2]
label = test_label(train_tree, test_data)
else:
label = train_tree[2]
                return label == test_data[-1]
all_num = len(dataset)
right = 0
for i in range(len(dataset)):
if test_label(tree,dataset[i]):
right += 1
return right/all_num
        # initialise the tree, i.e. the state before the first root split
tree = [[0,0],max_lable(train_dataset),max_lable(train_dataset)]
tree_data = train_dataset
        # initialise the tree that performs the first root split
next_tree = []
next_tree_data=[]
my_gain = next_opr(train_dataset)
key, root = key_root(my_gain)
next_tree.append([key, root])
next_tree.append([])
next_tree.append([])
next_tree_data.append([key,root])
next_tree_data.append([])
next_tree_data.append([])
sub_left, sub_right = sub_root(train_dataset, key, root)
next_tree[1] = max_lable(sub_left)
next_tree[2] = max_lable(sub_right)
next_tree_data[1] = sub_left
next_tree_data[2] = sub_right
        # check whether all samples in the left/right subtree share one label
def test(dataset):
num = []
for i in range(len(self.label)):
num.append(0)
for i in range(len(dataset)):
for j in range(len(self.label)):
if dataset[i][-1] == self.label[j]:
num[j] += 1
if max(num) == len(dataset):
return self.label[num.index(max(num))]
else:
return dataset
        # pass in two trees and decide whether to expand or prune
def pruning(first,first_data,next,next_data,test_dataset):
def next_tree(tree_data,bool):
                # decide whether this step should keep expanding the tree
                # next_data is what should be passed in here
if bool :
for i in range(len(tree_data)):
if i != 0:
tree_data[i] = test(tree_data[i])
for i in range(len(tree_data)):
if i != 0 and type(tree_data[i]) == type([]):
                            next_data = copy.deepcopy(tree_data)  # deep copy so the two trees do not share memory
if tree_data[i] not in self.label:
dataset = tree_data[i]
tree_data[i] = []
gains = next_opr(dataset)
key, root = key_root(gains)
tree_data[i].append([key, root])
tree_data[i].append([])
tree_data[i].append([])
next_data[i] =[]
next_data[i].append([key,root])
next_data[i].append([])
next_data[i].append([])
left, right = sub_root(dataset, key, root)
next_data[i][1] = left
next_data[i][2] = right
tree_data[i][1] = max_lable(left)
tree_data[i][2] = max_lable(right)
                                break
else :
for i in range(len(tree_data)):
if i != 0:
tree_data[i] = test(tree_data[i])
for i in range(len(tree_data)):
if i != 0 and type(tree_data[i]) == type([]):
next_data = copy.copy(tree_data)
                            if isinstance(tree_data[i][1], str) and isinstance(tree_data[i][2], list):
                                # one child is already a label leaf; the samples live in the other child
                                dataset = tree_data[i][2]
                            elif isinstance(tree_data[i][2], str) and isinstance(tree_data[i][1], list):
                                dataset = tree_data[i][1]
                            else:
                                dataset = tree_data[i][1] + tree_data[i][2]
                            next_data[i] = max_lable(dataset)
                            break
for i in range(len(tree_data)):
if i != 0 and type(tree_data[i]) == type([]):
                        if isinstance(tree_data[i][1], str) and isinstance(tree_data[i][2], list):
                            dataset = tree_data[i][2]
                        elif isinstance(tree_data[i][2], str) and isinstance(tree_data[i][1], list):
                            dataset = tree_data[i][1]
                        else:
                            dataset = tree_data[i][1] + tree_data[i][2]
                        tree_data[i] = max_lable(dataset)
return tree_data,next_data
            while next != next_data:
if test_tree(first, test_dataset) < test_tree(next, test_dataset):
first = next
first_data = copy.deepcopy(next_data)
next, next_data = next_tree(next_data, True)
else:
next, next_data = next_tree(next_data, False)
return next
new_tree = pruning(tree,tree_data,next_tree,next_tree_data,test_dataset)
return new_tree
def post_pruning(self,train_dataset,test_dataset):
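        # post-pruning: grow the full tree first, then try collapsing internal nodes
        # into majority-label leaves, keeping a collapse whenever accuracy on
        # test_dataset does not drop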
        # compute information entropy
def ent(dataset):
num = []
for i in range(len(self.label)):
num.append(0)
for i in range(len(dataset)):
for j in range(len(self.label)):
if dataset[i][-1] == self.label[j]:
num[j] += 1
all = 0
for i in range(len(num)):
all += num[i]
ent_data = 0
for i in range(len(num)):
if num[i] != 0:
ent_data -= num[i] / all * math.log2(num[i] / all)
return ent_data
        def Gain(dataset, root, key):  # key is the feature index
            def Gain_sub(dataset, root, key, a):  # a is the index of the candidate split point
ent_data = ent(dataset)
sub_data_1 = []
sub_data_2 = []
for i in range(len(dataset)):
if dataset[i][key] < root[key][a]:
sub_data_1.append(dataset[i])
else:
sub_data_2.append(dataset[i])
                gain_data = (ent_data
                             - (len(sub_data_1) / len(dataset)) * ent(sub_data_1)
                             - (len(sub_data_2) / len(dataset)) * ent(sub_data_2))
return gain_data
gain = []
for i in range(len(root[key])):
gain.append(Gain_sub(dataset, root, key, i))
return max(gain), gain.index(max(gain))
def next_opr(dataset):
            # compute the information gain for every feature of this dataset; the result is a nested dict: outer key = feature index, inner key = split point, value = information gain
            # summarise the value set of every feature
feature = {}
for i in range(len(dataset[0]) - 1):
feature[i] = []
for i in range(len(dataset)):
for j in range(len(dataset[i]) - 1):
if dataset[i][j] not in feature[j]:
feature[j].append(dataset[i][j])
for i in range(len(feature.keys())):
feature[i] = sorted(feature[i])
            # candidate split points for the continuous values (midpoints of adjacent sorted values)
root = {}
for i in range(len(feature.keys())):
root[i] = []
for i in range(len(feature.keys())):
for j in range(len(feature[i])):
if j != len(feature[i]) - 1:
root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
                if root[i]:
                    # Gain already scans every candidate split point, so one call per feature suffices
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain
        # the tree is stored as a list [root, left, right]: element 0 holds the split, element 1 the left subtree, element 2 the right; a child that is not yet a label stores its dataset
        def key_root(my_gain):
            # parse the stored gains; return the feature index (key) and the split point (root) with the largest gain
            key_1 = list(my_gain.keys())
            best = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > best:
                        best = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root
tree = []
my_gain = next_opr(train_dataset)
key, root = key_root(my_gain)
tree.append([key, root])
tree.append([])
tree.append([])
        # split the remaining dataset using key and root
def sub_root(dataset, key, root):
sub_left = []
sub_right = []
for i in range(len(dataset)):
if dataset[i][key] < root:
sub_left.append(dataset[i])
else:
sub_right.append(dataset[i])
return sub_left, sub_right
sub_left, sub_right = sub_root(train_dataset, key, root)
tree[1] = sub_left
tree[2] = sub_right
        # check whether all samples in the left/right subtree share one label
def test(dataset):
num = []
for i in range(len(self.label)):
num.append(0)
for i in range(len(dataset)):
for j in range(len(self.label)):
if dataset[i][-1] == self.label[j]:
num[j] += 1
if max(num) == len(dataset):
return self.label[num.index(max(num))]
else:
return dataset
def next(tree):
for i in range(len(tree)):
if i != 0:
tree[i] = test(tree[i])
for i in range(len(tree)):
if i != 0:
if tree[i] not in self.label:
dataset = tree[i]
tree[i] = []
gains = next_opr(dataset)
key, root = key_root(gains)
tree[i].append([key, root])
tree[i].append([])
tree[i].append([])
left, right = sub_root(dataset, key, root)
tree[i][1] = test(left)
tree[i][2] = test(right)
next(tree[i])
next(tree)
tree_data = copy.deepcopy(tree)
def clear_lable_add_data(tree_data,train_dataset):
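            # rebuild a copy of the tree whose leaves hold the training samples that
            # reach them: label leaves are cleared first, then every row is routed down the splits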
def clear_lable(tree_data):
for i in range(len(tree_data)):
if i!= 0:
if type(tree_data[i]) == type('str'):
tree_data[i] = []
                        elif type(tree_data[i]) == type([]):
clear_lable(tree_data[i])
clear_lable(tree_data)
def put_data(tree,data):
if data[tree[0][0]] < tree[0][1] :
if len(tree[1]) != 3 :
tree[1].append(data)
else:
if len(tree[1][0]) == 2:
put_data(tree[1],data)
else:
tree[1].append(data)
                else:
if len(tree[2]) != 3 :
tree[2].append(data)
else:
if len(tree[2][0]) == 2:
put_data(tree[2],data)
else:
tree[2].append(data)
for i in range(len(train_dataset)):
put_data(tree_data,train_dataset[i])
return tree_data
tree_data = clear_lable_add_data(tree_data,train_dataset)
def post_purning(tree,tree_data,test_dataset):
def test_tree(tree, dataset):
                # measure this tree's accuracy on dataset
def test_label(train_tree, test_data):
                    # check whether the tree classifies this sample correctly
label = None
if test_data[train_tree[0][0]] < train_tree[0][1]:
if train_tree[1] not in self.label:
train_tree = train_tree[1]
label = test_label(train_tree, test_data)
else:
label = train_tree[1]
return label
else:
if train_tree[2] not in self.label:
train_tree = train_tree[2]
label = test_label(train_tree, test_data)
else:
label = train_tree[2]
return label
return label
all_num = len(dataset)
right = 0
for i in range(len(dataset)):
if test_label(tree, dataset[i]) == dataset[i][-1]:
right += 1
return right / all_num
def branchs(tree):
def branch(tree):
list = []
for i in range(len(tree)):
if len(tree[i]) == 3 and len(tree[i][0]) == 2 :
if len(tree[i][1]) == 3 or len(tree[i][2]) ==3:
list.append(i)
list.append(branch(tree[i]))
elif len(tree[i][1][0]) != 2 and len(tree[i][2][0]) !=2:
list.append(i)
return list
branch_dict = branch(tree)
return branch_dict
def process(br):
all =[]
def all_num(list):
key = 0
for i in range(len(list)):
if type(list[i]) == type(list):
return False
else:
key+=1
if key == len(list):
return True
def empty_list(list):
                    # check whether the list contains an empty sub-list
all_key = 0
for i in range(len(list)):
if list[i] == []:
return True
else:
all_key +=1
if all_key == len(list):
return False
def br_tree(br, list):
for i in range(len(br)):
if type(br[i]) == type([]) and len(br[i]) != 1:
br_tree(br[i], list)
if empty_list(br[i]):
list.insert(0, br[i][br[i].index([])-1])
br[i].pop(br[i].index([]))
list.insert(0, br[i - 1])
else:
list.insert(0,br[i-1])
break
elif type(br[i]) == type([]) and len(br[i]) == 1:
num = br[i].pop(0)
list.append(num)
break
elif type(br[i]) == type([]) and len(br[i]) == 0:
br.pop(br.index([]))
br.pop(-1)
elif all_num(br):
num = br.pop(0)
list.append(num)
break
elif type(br[0]) == type(1) and type(br[1]) == type([]) and br[1] == []:
br.pop(-1)
break
                while br != []:
list = []
br_tree(br, list)
if len(list) == 0:
list.append(br[0])
br.pop(-1)
elif len(list) == 1:
list.insert(0,br[0])
all.append(list)
return all
            def max_lable(dataset):
                # return the label with the largest sample count
                if type(dataset) == type([]):
                    num = []
                    for i in range(len(self.label)):
                        num.append(0)
                    for i in range(len(dataset)):
                        for j in range(len(self.label)):
                            if dataset[i][-1] == self.label[j]:
                                num[j] += 1
                    return self.label[num.index(max(num))]
                elif type(dataset) == type('abc'):
                    return dataset
def tree_lable(tree_data):
for i in range(len(tree_data)):
if i != 0 and len(tree_data[i]) ==3 and len(tree_data[i][0]) == 2:
tree_data[i]=tree_lable(tree_data[i])
elif i != 0 and len(tree_data[i]) !=3 :
tree_data[i] = max_lable(tree_data[i])
elif i !=0 and len(tree_data[i]) ==3 and len(tree_data[i][0]) != 2:
tree_data[i] = max_lable(tree_data[i])
return tree_data
br = branchs(tree)
first = tree
first_data = tree_data
br = process(br)
def tree_list(tree_data,list):
                # prune the tree along a path of child indices encoded as a list
def the_Data(tree):
the_da = []
for i in range(len(tree)):
if len(tree[i]) == 3 and len(tree[i][0]) == 2:
the_da.extend(the_Data(tree[i]))
elif i != 0 and len(tree[i]) !=3 :
the_da.extend(tree[i])
elif i != 0 and len(tree[i]) == 3 and type(tree[i][0][-1]) == type('str'):
the_da.extend(tree[i])
return the_da
def dir(tree_data,list):
if len(list) == 0 :
tree_data = the_Data(tree_data)
return tree_data
else:
tree_data[list[0]] = dir(tree_data[list[0]],list[1:])
return tree_data
dir(tree_data,list)
return tree_data
            # use every processed pruning-path sequence to attempt a prune
for i in range(len(br)):
second_data = copy.deepcopy(first_data)
tree_list(second_data,br[i])
data = copy.deepcopy(second_data)
second = tree_lable(data)
firs_acc = test_tree(first,test_dataset)
sec_acc = test_tree(second,test_dataset)
# print(firs_acc)
# print(sec_acc)
if firs_acc <= sec_acc:
first = second
first_data = second_data
return first
tree = post_purning(tree,tree_data,test_dataset)
return tree
def label_sample(self,train_tree,test_data):
def process(train_tree,test_data):
label = None
if test_data[train_tree[0][0]] < train_tree[0][1]:
if train_tree[1] not in self.label:
train_tree = train_tree[1]
label = process(train_tree,test_data)
else:
label = train_tree[1]
else:
if train_tree[2] not in self.label:
train_tree = train_tree[2]
label =process(train_tree,test_data)
else:
label = train_tree[2]
return label
label = process(train_tree,test_data)
# print(test_data[:-1],'\'s true label:',test_data[-1],' predict label is :',label)
return label
def k_accuracy(self,data):
def accuracy(train,test):
acc = 0
my_tree = self.train_tree(train)
for i in range(len(test)):
if test[i][-1] == self.label_sample(my_tree,test[i]):
acc += 1
return acc/len(test)
for i in range(len(data)):
print('No-pruning Processing',i+1,'batch.....')
print('The accuracy is',accuracy(data[i][0],data[i][1]),'.....')
print('Batch',i+1,'is finished.....')
        print('The evaluation is finished..')
print('*****************************************')
def pre_accuracy(self,data):
def accuracy(train,test):
acc = 0
my_tree = self.pre_pruning(train,test)
for i in range(len(test)):
if test[i][-1] == self.label_sample(my_tree,test[i]):
acc += 1
return acc/len(test)
for i in range(len(data)):
print('Pre-pruning Processing',i+1,'batch.....')
print('The accuracy is',accuracy(data[i][0],data[i][1]),'.....')
print('Batch',i+1,'is finished.....')
        print('The pruning is finished..')
print('*****************************************')
def post_accuracy(self,data):
def accuracy(train,test):
acc = 0
my_tree = self.post_pruning(train,test)
for i in range(len(test)):
if test[i][-1] == self.label_sample(my_tree,test[i]):
acc += 1
return acc/len(test)
for i in range(len(data)):
print('Post-pruning Processing',i+1,'batch.....')
print('The accuracy is',accuracy(data[i][0],data[i][1]),'.....')
print('Batch',i+1,'is finished.....')
        print('The pruning is finished..')
print('*****************************************')
if __name__ == '__main__':
label = ['Iris-setosa','Iris-versicolor','Iris-virginica']
tree = decisionTree('iris.data',label)
k_data = tree.k_data(5)
tree.k_accuracy(k_data)
tree.pre_accuracy(k_data)
tree.post_accuracy(k_data)
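As a sanity check for the hand-rolled tree above, here is a minimal sketch that trains scikit-learn's DecisionTreeClassifier on the same file. The sklearn dependency, the 80/20 random split, and the iris.data path are assumptions, not part of the original code.
import random
from sklearn.tree import DecisionTreeClassifier
def load_iris_rows(path='iris.data'):
    # parse "5.1,3.5,1.4,0.2,Iris-setosa" rows into a feature matrix and label list
    X, y = [], []
    with open(path) as f:
        for line in f:
            parts = line.strip().split(',')
            if len(parts) == 5 and parts[-1] != '':
                X.append([float(v) for v in parts[:-1]])
                y.append(parts[-1])
    return X, y
# assumes scikit-learn is installed and iris.data sits next to this script
X, y = load_iris_rows()
idx = list(range(len(X)))
random.shuffle(idx)
split = int(0.8 * len(idx))
train, test = idx[:split], idx[split:]
clf = DecisionTreeClassifier(criterion='entropy')  # same entropy criterion as above
clf.fit([X[i] for i in train], [y[i] for i in train])
print('sklearn accuracy:', clf.score([X[i] for i in test], [y[i] for i in test]))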
The dataset is given below.
Description of the iris (Iris) dataset:
Iris is a genus of flowering plants. The dataset records the length and width of the sepal and petal, four attributes in total, and the plants fall into three classes, so each record has four feature variables and one class variable. There are 150 samples covering three subspecies: Iris-setosa, Iris-versicolor and Iris-virginica. In other words, every sample carries four attribute values and the task is a three-class classification problem.
For example:
Sample 1: 5.1, 3.5, 1.4, 0.2, Iris-setosa
Here "5.1, 3.5, 1.4, 0.2" are the values of the sample's four attributes and "Iris-setosa" is its class; a small parsing sketch follows.
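A quick check of how one row parses (a hypothetical snippet; the row string is the sample above):
row = '5.1,3.5,1.4,0.2,Iris-setosa'.split(',')
features, label = [float(v) for v in row[:-1]], row[-1]
print(features, label)  # [5.1, 3.5, 1.4, 0.2] Iris-setosa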
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica