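"""机器学习小实验：决策树实验

对鸢尾花卉（Iris）数据集训练决策树。

Python 实现，代码比较乱。

Machine-learning mini-experiment: training decision trees on the Iris
dataset (plain-Python implementation; the code is admittedly messy).
"""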

import random
import math
import copy
import sys


class decisionTree (object):
    def __init__(self, file, label):
        self.file = file
        self.label = label
        self.dataset = self.initDataset()
        self.readfile()
    def initDataset(self):
        dataset={}
        for i in range(len(self.label)):
            dataset[self.label[i]] = []
        return dataset
    def readfile(self):
        myfile = open(self.file, 'r')
        for line in myfile:
            line = line.strip()
            data = line.split(',')
            if data[-1] != '':
                self.dataset[data[-1]].append(data[:-1])
        for i in range(len(self.label)):
            for j in range(len(self.dataset[self.label[i]])):
                for m in range(len(self.dataset[self.label[i]][j])):
                    self.dataset[self.label[i]][j][m] = float(self.dataset[self.label[i]][j][m])

    # Sample each label separately so that every fold is class-balanced (stratified sampling).
    def k_cross(self,k):

        num = []
        for i in range(len(self.label)):
            num.append(len(self.dataset[self.label[i]]))
        all = 0
        for i in range(len(num)):
            all += num[i]
        for i in range(len(num)):
            num[i] = num[i] / k

        for i in range(len(num)):
            if num[i] - int(num[i]) > 0.5 :
                num[i] = int(num[i]) + 1
            else :
                num[i] = int(num[i])


        #进行k折样本取样
        data = []
        for i in range(k):
            data.append([])
        all_list = []
        for i in range(len(self.label)):
            all_list.append([])
        for i in range(len(self.label)):
            start =0
            for j in range(len(self.dataset[self.label[i]])):
                all_list[i].append(start)
                start += 1
        for k in range(k):
            for i in range(len(self.label)):
                list = random.sample(all_list[i],num[i])
                all_list[i] = set(all_list[i])^set(list)
                for j in range(len(list)):
                    pop_data = self.dataset[self.label[i]][list[j]]
                    pop_data.append(self.label[i])
                    data[k].append(pop_data)
        return data

    # Return all k folds as [train, test] pairs: [[[train], [test]], [[train], [test]], ...]
    def k_data(self,k):

        dataset = self.k_cross(k)
        data = []
        for i in range(len(dataset)):
            data.append([[],[]])         #返回k组 [[训练集] ,[测试集]] 对
        for i in range(len(dataset)):
            data[i][1] = dataset[i]
            for j in range(len(dataset)):
                if j != i :
                    for  k in range(len(dataset[j])):
                        data[i][0].append(dataset[j][k])
        return data

    # `dataset` is the caller-supplied training set: a list of rows [[...], [...], ...],
    # each row holding the feature values followed by the label. Returns the custom list-based tree.
    def train_tree(self,dataset):
        #计算信息增益
        def ent(dataset):
            num=[]
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                     if dataset[i][-1] == self.label[j] :
                        num[j] += 1
            all = 0
            for i in range(len(num)):
                all += num[i]
            ent_data = 0
            for i in range(len(num)):
                if num[i] != 0:
                    ent_data -= num[i]/all * math.log2(num[i]/all)
            return ent_data
        def Gain(dataset,root,key):  #key 为第几类特征
            def Gain_sub(dataset,root,key,a):  # a 为第几个化分点
                ent_data = ent(dataset)
                sub_data_1 = []
                sub_data_2 = []
                for i in range(len(dataset)):
                    if  dataset[i][key] < root[key][a]:
                        sub_data_1.append(dataset[i])
                    else :
                        sub_data_2.append(dataset[i])
                gain_data = 0
                gain_data = ent_data - (len(sub_data_1)/len(dataset)) * ent(sub_data_1) - (len(sub_data_2)/len(dataset)) * ent(sub_data_2)
                return  gain_data

            gain=[]
            for i in range(len(root[key])):
                gain.append(Gain_sub(dataset,root,key,i))
            return max(gain),gain.index(max(gain))

        def next_opr(dataset):
            # 对于这个dataset中的每种特征进行信息增益计算 结构为双重字典  第一个key为第几类特征 第二个key为划分点 最后value为 信息增益
            # 对特征的属性值总结
            feature = {}
            for i in range(len(dataset[0]) - 1):
                feature[i] = []
            for i in range(len(dataset)):
                for j in range(len(dataset[i]) - 1):
                    if dataset[i][j] not in feature[j]:
                        feature[j].append(dataset[i][j])
            for i in range(len(feature.keys())):
                feature[i] = sorted(feature[i])
            # 划分连续值根节点
            root = {}
            for i in range(len(feature.keys())):
                root[i] = []
            for i in range(len(feature.keys())):
                for j in range(len(feature[i])):
                    if j != len(feature[i]) - 1:
                        root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
            for i in range(len(gain.keys())):
                for j in range(len(root[i])):
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain

        #以list为最后树的储存方式   [root,[],[]]    list[0]为root list[1] 为左子树 list[2]为右子树 若没有label 则子节点存储dataset

        def key_root(my_gain):
            #对于存储有信息增益的结构进行解析 返回key 第几类特征 root 划分点
            key_1 = list(my_gain.keys())
            max = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > max:
                        max = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key,root

        tree = []
        my_gain = next_opr(dataset)
        key,root = key_root(my_gain)
        tree.append([key,root])
        tree.append([])
        tree.append([])

        #通过key  root 划分剩余数据集 dataset
        def sub_root(dataset,key,root):
            sub_left = []
            sub_right = []
            for i in range(len(dataset)):
                if dataset[i][key] < root:
                    sub_left.append(dataset[i])
                else:
                    sub_right.append(dataset[i])
            return sub_left,sub_right

        sub_left, sub_right = sub_root(dataset,key,root)
        tree[1] = sub_left
        tree[2] = sub_right
        #检测左子树右子树中的样本是否为同一label
        def test(dataset):
            num=[]
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] +=1
            if  max(num) == len(dataset):
                return self.label[num.index(max(num))]
            else:
                return dataset

        def next(tree):
            for i in range(len(tree)):
                if i != 0:
                    tree[i] = test(tree[i])
            for i in range(len(tree)):
                if i != 0:
                    if tree[i] not in self.label:
                        dataset = tree[i]
                        tree[i] = []
                        gains = next_opr(dataset)
                        key,root = key_root(gains)
                        tree[i].append([key,root])
                        tree[i].append([])
                        tree[i].append([])
                        left,right = sub_root(dataset,key,root)
                        tree[i][1] = test(left)
                        tree[i][2] = test(right)
                        next(tree[i])
        next(tree)
        return tree
    # Train a decision tree with pre-pruning.
    def pre_pruning(self,train_dataset,test_dataset):
        """Train a decision tree with pre-pruning, validating on test_dataset.

        Starts from a trivial tree whose two leaves both carry the majority
        training label, plus its first root split. It then repeatedly expands
        or collapses branches while comparing accuracy on ``test_dataset``
        (see ``pruning`` below).

        Args:
            train_dataset: training rows, each ``[feature_0, ..., feature_n, label]``.
            test_dataset: validation rows in the same format.

        Returns:
            The tree returned by ``pruning`` (nested lists
            ``[[key, root], left, right]``).

        NOTE(review): several parts of this method look defective (flagged
        inline); re-verify its output before relying on it.
        """

        def next_opr(dataset):
            # For every feature of dataset, record the information gain of its
            # best threshold as a nested dict:
            # {feature index: {split threshold: information gain}}.
            def Gain(dataset, root, key):  # key: feature index
                # Returns (largest gain over root[key], index of that threshold).
                # Shannon entropy (base 2) of the labels (last column) of dataset.
                def ent(dataset):
                    num = []
                    for i in range(len(self.label)):
                        num.append(0)
                    for i in range(len(dataset)):
                        for j in range(len(self.label)):
                            if dataset[i][-1] == self.label[j]:
                                num[j] += 1
                    all = 0
                    for i in range(len(num)):
                        all += num[i]
                    ent_data = 0
                    for i in range(len(num)):
                        if num[i] != 0:
                            ent_data -= num[i] / all * math.log2(num[i] / all)
                    return ent_data

                def Gain_sub(dataset, root, key, a):  # a: index of the candidate threshold in root[key]
                    # Information gain of splitting dataset at root[key][a].
                    ent_data = ent(dataset)
                    sub_data_1 = []
                    sub_data_2 = []
                    for i in range(len(dataset)):
                        if dataset[i][key] < root[key][a]:
                            sub_data_1.append(dataset[i])
                        else:
                            sub_data_2.append(dataset[i])
                    gain_data = 0
                    gain_data = ent_data - (len(sub_data_1) / len(dataset)) * ent(sub_data_1) - (len(sub_data_2) / len(
                        dataset)) * ent(sub_data_2)
                    return gain_data

                gain = []
                for i in range(len(root[key])):
                    gain.append(Gain_sub(dataset, root, key, i))
                return max(gain), gain.index(max(gain))
            # Collect the sorted distinct values of each feature.
            feature = {}
            for i in range(len(dataset[0]) - 1):
                feature[i] = []
            for i in range(len(dataset)):
                for j in range(len(dataset[i]) - 1):
                    if dataset[i][j] not in feature[j]:
                        feature[j].append(dataset[i][j])
            for i in range(len(feature.keys())):
                feature[i] = sorted(feature[i])
            # Candidate thresholds: midpoints between consecutive distinct values.
            root = {}
            for i in range(len(feature.keys())):
                root[i] = []
            for i in range(len(feature.keys())):
                for j in range(len(feature[i])):
                    if j != len(feature[i]) - 1:
                        root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
            for i in range(len(gain.keys())):
                # NOTE(review): Gain(dataset, root, i) does not depend on j, so
                # this repeats the same computation len(root[i]) times.
                for j in range(len(root[i])):
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain
        # Trees are stored as lists [root, left, right]: list[0] is the split
        # [key, threshold], list[1] the left subtree and list[2] the right
        # subtree. A child that is not (yet) a label holds its dataset.
        def key_root(my_gain):
            # Return (key, root) with the largest positive gain: key = feature
            # index, root = threshold. Falls back to (0, 0) if no gain is positive.

            key_1 = list(my_gain.keys())
            max = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > max:
                        max = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root

        # Split dataset by (key, root): rows with dataset[i][key] < root go left, the rest right.
        def sub_root(dataset, key, root):
            sub_left = []
            sub_right = []
            for i in range(len(dataset)):
                if dataset[i][key] < root:
                    sub_left.append(dataset[i])
                else:
                    sub_right.append(dataset[i])
            return sub_left, sub_right
        def max_lable(dataset):
            # Return the label carried by the most rows (the first such label
            # on ties; self.label[0] for an empty list). A string argument is
            # returned unchanged.
            # NOTE(review): any other argument type (e.g. None) falls through
            # and returns None.
            if type(dataset) == type([]):
                num = []
                for i in range(len(self.label)):
                    num.append(0)
                for i in range(len(dataset)):
                    for j in range(len(self.label)):
                        if dataset[i][-1] == self.label[j]:
                            num[j] += 1
                max = 0
                for i in range(len(num)):
                    if num[i] > max:
                        max = num[i]
                return self.label[num.index(max)]
            elif type(dataset) == type('abc'):
                return dataset
        def test_tree(tree,dataset):
            # Accuracy of `tree` on dataset: the fraction of rows test_label
            # marks correct. NOTE(review): ZeroDivisionError if dataset is empty.
            def test_label(train_tree,test_data):
                # True if the tree's prediction equals test_data's label (its
                # last element), else False.
                label = None
                if test_data[train_tree[0][0]] < train_tree[0][1]:
                    if train_tree[1] not in self.label:
                        train_tree = train_tree[1]
                        # NOTE(review): the recursive call returns True/False,
                        # not a label, so the `label == test_data[-1]` check
                        # below is False for every sample that reaches a subtree.
                        label = test_label(train_tree, test_data)
                    else:
                        label = train_tree[1]
                else:
                    if train_tree[2] not in self.label:
                        train_tree = train_tree[2]
                        # NOTE(review): same bool-vs-label issue as above.
                        label = test_label(train_tree, test_data)
                    else:
                        label = train_tree[2]
                if label == test_data[-1]:
                    return True
                else:
                    return False
            all_num = len(dataset)
            right = 0
            for i in  range(len(dataset)):
                if  test_label(tree,dataset[i]):
                    right += 1
            return right/all_num

        # Initial tree, before the first root split: a dummy split [0, 0] whose
        # two leaves are both the majority training label.
        tree = [[0,0],max_lable(train_dataset),max_lable(train_dataset)]
        tree_data = train_dataset
        # Tree after the first root split. next_tree stores majority labels,
        # next_tree_data stores the corresponding row subsets.
        next_tree = []
        next_tree_data=[]
        my_gain = next_opr(train_dataset)
        key, root = key_root(my_gain)
        next_tree.append([key, root])
        next_tree.append([])
        next_tree.append([])
        next_tree_data.append([key,root])
        next_tree_data.append([])
        next_tree_data.append([])
        sub_left, sub_right = sub_root(train_dataset, key, root)
        next_tree[1] = max_lable(sub_left)
        next_tree[2] = max_lable(sub_right)
        next_tree_data[1] = sub_left
        next_tree_data[2] = sub_right

        # If every row of dataset has the same label, return that label
        # (self.label[0] for an empty list); otherwise return dataset unchanged.
        def test(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            if max(num) == len(dataset):
                return self.label[num.index(max(num))]
            else:
                return dataset

        # Compare the current tree `first` with its candidate expansion `next`
        # on test_dataset and decide whether to keep expanding.
        def pruning(first,first_data,next,next_data,test_dataset):
            def next_tree(tree_data,bool):
                # bool=True: split the first child of tree_data that is still a
                # list. bool=False: collapse list children back into majority
                # labels. Returns (tree_data, next_data).
                # tree_data is expected to be the data-carrying tree (next_data).
                if bool :
                    for i in range(len(tree_data)):
                        if i != 0:
                            tree_data[i] = test(tree_data[i])
                    for i in range(len(tree_data)):
                        # NOTE(review): an already-expanded child
                        # [[key, root], left, right] is also a list and would be
                        # treated as a row list here.
                        if i != 0 and type(tree_data[i]) == type([]):
                            next_data = copy.deepcopy(tree_data)  # deep copy so the two trees share no memory
                            if tree_data[i] not in self.label:

                                dataset = tree_data[i]
                                tree_data[i] = []
                                gains = next_opr(dataset)
                                key, root = key_root(gains)

                                tree_data[i].append([key, root])
                                tree_data[i].append([])
                                tree_data[i].append([])

                                next_data[i] =[]
                                next_data[i].append([key,root])
                                next_data[i].append([])
                                next_data[i].append([])

                                left, right = sub_root(dataset, key, root)
                                next_data[i][1] = left
                                next_data[i][2] = right
                                tree_data[i][1] = max_lable(left)
                                tree_data[i][2] = max_lable(right)
                                break;
                else :
                    for i in range(len(tree_data)):
                        if i != 0:
                            tree_data[i] = test(tree_data[i])
                    for i in range(len(tree_data)):
                        if i != 0 and type(tree_data[i]) == type([]):
                            next_data = copy.copy(tree_data)
                            dataset = None
                            # NOTE(review): list.append returns None, so in the
                            # first two branches dataset is None (and the label
                            # string is appended into the row list). Then
                            # max_lable(None) returns None.
                            if type(tree_data[i][1]) == type('str') and type(tree_data[i][2]) == type([]):
                                str = tree_data[i][1]
                                dataset = tree_data[i][2].append(str)
                            elif type(tree_data[i][2]) == type('str') and type(tree_data[i][1]) == type([]):
                                str = tree_data[i][2]
                                dataset = tree_data[i][1].append(str)
                            else:
                                dataset = tree_data[i][1] + tree_data[i][2]
                            next_data[i] = max_lable(dataset)
                            break;
                    # Collapse every remaining list child of tree_data into a
                    # label (same append-returns-None issue as above).
                    for i in range(len(tree_data)):
                        if i != 0 and type(tree_data[i]) == type([]):
                            dataset = None
                            if type(tree_data[i][1]) == type('str') and type(tree_data[i][2]) == type([]):
                                dataset = tree_data[i][2].append(tree_data[i][1])
                            elif type(tree_data[i][2]) == type('str') and type(tree_data[i][1]) == type([]):
                                dataset = tree_data[i][1].append(tree_data[i][2])
                            else:
                                dataset = tree_data[i][1] + tree_data[i][2]
                            tree_data[i] = max_lable(dataset)

                # NOTE(review): next_data is bound only inside the loops above;
                # if no child is a list this raises UnboundLocalError.
                return tree_data,next_data

            # NOTE(review): the loop ends only when the label tree `next` equals
            # the data tree `next_data`; termination is not evident from this code.
            while(next != next_data):
                if test_tree(first, test_dataset) < test_tree(next, test_dataset):
                    first = next
                    first_data = copy.deepcopy(next_data)
                    next, next_data = next_tree(next_data, True)

                else:
                    next, next_data = next_tree(next_data, False)
            # NOTE(review): returns the last candidate `next`, not the best
            # accepted tree `first`.
            return next
        new_tree = pruning(tree,tree_data,next_tree,next_tree_data,test_dataset)
        return new_tree
    def post_pruning(self,train_dataset,test_dataset):
        """Grow a full decision tree, then post-prune it against test_dataset.

        A full tree is grown on ``train_dataset`` using information-gain
        splits at midpoint thresholds. The training rows are then routed back
        down to its leaves. Candidate internal nodes are collapsed into
        majority-label leaves whenever that does not lower accuracy on
        ``test_dataset``.

        Args:
            train_dataset: training rows, each ``[feature_0, ..., feature_n, label]``.
            test_dataset: validation rows in the same format.

        Returns:
            The pruned tree as nested lists ``[[key, root], left, right]``.
        """
        # Shannon entropy (base 2) of the labels (last column) of dataset.
        def ent(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            all = 0
            for i in range(len(num)):
                all += num[i]
            ent_data = 0
            for i in range(len(num)):
                if num[i] != 0:
                    ent_data -= num[i] / all * math.log2(num[i] / all)
            return ent_data

        def Gain(dataset, root, key):  # key: feature index
            # Returns (largest gain over root[key], index of that threshold).
            def Gain_sub(dataset, root, key, a):  # a: index of the candidate threshold in root[key]
                # Information gain of splitting dataset at root[key][a].
                ent_data = ent(dataset)
                sub_data_1 = []
                sub_data_2 = []
                for i in range(len(dataset)):
                    if dataset[i][key] < root[key][a]:
                        sub_data_1.append(dataset[i])
                    else:
                        sub_data_2.append(dataset[i])
                gain_data = 0
                gain_data = ent_data - (len(sub_data_1) / len(dataset)) * ent(sub_data_1) - (len(sub_data_2) / len(
                    dataset)) * ent(sub_data_2)
                return gain_data

            gain = []
            for i in range(len(root[key])):
                gain.append(Gain_sub(dataset, root, key, i))
            return max(gain), gain.index(max(gain))

        def next_opr(dataset):
            # For every feature of dataset, record the information gain of its
            # best threshold as a nested dict:
            # {feature index: {split threshold: information gain}}.
            # Collect the sorted distinct values of each feature.
            feature = {}
            for i in range(len(dataset[0]) - 1):
                feature[i] = []
            for i in range(len(dataset)):
                for j in range(len(dataset[i]) - 1):
                    if dataset[i][j] not in feature[j]:
                        feature[j].append(dataset[i][j])
            for i in range(len(feature.keys())):
                feature[i] = sorted(feature[i])
            # Candidate thresholds: midpoints between consecutive distinct values.
            root = {}
            for i in range(len(feature.keys())):
                root[i] = []
            for i in range(len(feature.keys())):
                for j in range(len(feature[i])):
                    if j != len(feature[i]) - 1:
                        root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
            for i in range(len(gain.keys())):
                # NOTE(review): Gain(dataset, root, i) does not depend on j, so
                # this repeats the same computation len(root[i]) times.
                for j in range(len(root[i])):
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain

        # Trees are stored as lists [root, left, right]: list[0] is the split
        # [key, threshold], list[1] the left subtree and list[2] the right
        # subtree. A child that is not (yet) a label holds its dataset.

        def key_root(my_gain):
            # Return (key, root) with the largest positive gain: key = feature
            # index, root = threshold. Falls back to (0, 0) if no gain is positive.
            key_1 = list(my_gain.keys())
            max = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > max:
                        max = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root

        # Root split of the full tree.
        tree = []
        my_gain = next_opr(train_dataset)
        key, root = key_root(my_gain)
        tree.append([key, root])
        tree.append([])
        tree.append([])

        # Split dataset by (key, root): rows with dataset[i][key] < root go left, the rest right.
        def sub_root(dataset, key, root):
            sub_left = []
            sub_right = []
            for i in range(len(dataset)):
                if dataset[i][key] < root:
                    sub_left.append(dataset[i])
                else:
                    sub_right.append(dataset[i])
            return sub_left, sub_right

        sub_left, sub_right = sub_root(train_dataset, key, root)
        tree[1] = sub_left
        tree[2] = sub_right

        # If every row of dataset has the same label, return that label
        # (self.label[0] for an empty list); otherwise return dataset unchanged.
        def test(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            if max(num) == len(dataset):
                return self.label[num.index(max(num))]
            else:
                return dataset

        def next(tree):
            # Grow the full tree in place: pure children become labels, mixed
            # children are split again recursively.
            # NOTE(review): if no threshold separates a mixed subset (e.g.
            # identical features, different labels), key_root returns (0, 0).
            # For non-negative features one side then keeps the whole subset,
            # and this recursion never terminates.
            for i in range(len(tree)):
                if i != 0:
                    tree[i] = test(tree[i])
            for i in range(len(tree)):
                if i != 0:
                    if tree[i] not in self.label:
                        dataset = tree[i]
                        tree[i] = []
                        gains = next_opr(dataset)
                        key, root = key_root(gains)
                        tree[i].append([key, root])
                        tree[i].append([])
                        tree[i].append([])
                        left, right = sub_root(dataset, key, root)
                        tree[i][1] = test(left)
                        tree[i][2] = test(right)
                        next(tree[i])

        next(tree)
        tree_data = copy.deepcopy(tree)
        def clear_lable_add_data(tree_data,train_dataset):
            # Replace every leaf label of tree_data with an empty list, then
            # route each training row down the tree into the leaf it reaches.
            def clear_lable(tree_data):
                for i in range(len(tree_data)):
                    if i!= 0:
                        if type(tree_data[i]) == type('str'):
                            tree_data[i] = []
                        elif type(tree_data) == type([]):
                            clear_lable(tree_data[i])
            clear_lable(tree_data)
            def put_data(tree,data):
                # A child of length 3 whose first element has length 2 is
                # treated as a subtree; anything else as a leaf row list.
                # NOTE(review): rows with data[key] == threshold match neither
                # branch and are dropped (sub_root sends them right). The
                # length heuristic can also misread a leaf holding exactly 3
                # rows of length 2.
                if data[tree[0][0]] < tree[0][1] :
                    if len(tree[1]) != 3 :
                        tree[1].append(data)
                    else:
                        if len(tree[1][0]) == 2:
                            put_data(tree[1],data)
                        else:
                            tree[1].append(data)
                elif data[tree[0][0]] > tree[0][1] :
                    if len(tree[2]) != 3 :
                        tree[2].append(data)
                    else:
                        if len(tree[2][0]) == 2:
                            put_data(tree[2],data)
                        else:
                            tree[2].append(data)
            for i in range(len(train_dataset)):
                put_data(tree_data,train_dataset[i])
            return tree_data

        tree_data = clear_lable_add_data(tree_data,train_dataset)
        def post_purning(tree,tree_data,test_dataset):
            # Post-prune `tree` (labels at the leaves) using `tree_data` (the
            # same tree with training rows at the leaves) and test_dataset.
            def test_tree(tree, dataset):
                # Accuracy of `tree` on dataset (fraction of rows predicted
                # correctly). NOTE(review): ZeroDivisionError if dataset is empty.
                def test_label(train_tree, test_data):
                    # Predicted label of test_data: go left when
                    # test_data[key] < root, otherwise right, until a child that
                    # is one of self.label is reached.
                    label = None
                    if test_data[train_tree[0][0]] < train_tree[0][1]:
                        if train_tree[1] not in self.label:
                            train_tree = train_tree[1]
                            label = test_label(train_tree, test_data)
                        else:
                            label = train_tree[1]
                            return label
                    else:
                        if train_tree[2] not in self.label:
                            train_tree = train_tree[2]
                            label = test_label(train_tree, test_data)
                        else:
                            label = train_tree[2]
                            return label
                    return label
                all_num = len(dataset)
                right = 0
                for i in range(len(dataset)):
                    if test_label(tree, dataset[i]) == dataset[i][-1]:
                        right += 1
                return right / all_num
            def branchs(tree):
                # Nested list of child indices locating internal nodes
                # (candidates for pruning). A node with a length-3 child is
                # recorded together with the nested result for that node.
                # NOTE(review): subtrees and leaves are told apart only by
                # length (len == 3, len(node[0]) == 2), so a 3-character label
                # would be misread as a subtree.

                def branch(tree):
                    list = []
                    for i in range(len(tree)):
                        if len(tree[i]) == 3 and len(tree[i][0]) == 2 :
                            if len(tree[i][1]) == 3 or len(tree[i][2]) ==3:
                                list.append(i)
                                list.append(branch(tree[i]))
                            elif len(tree[i][1][0]) != 2 and len(tree[i][2][0]) !=2:
                                list.append(i)

                    return list
                branch_dict = branch(tree)
                return branch_dict
            def process(br):
                # Turn the nested index structure from branchs() into a list of
                # index paths (presumably one per candidate node). Consumes br
                # destructively via pop().
                # NOTE(review): the traversal order is hard to verify from the code.
                all =[]
                def all_num(list):
                    # True if no element of `list` is itself a list.
                    key = 0
                    for i in range(len(list)):
                        if type(list[i]) == type(list):
                            return False
                        else:
                            key+=1
                    if key == len(list):
                        return True
                def empty_list(list):
                    # True if any element of `list` is an empty list.
                    all_key = 0
                    for i in range(len(list)):
                        if  list[i] == []:
                            return True
                        else:
                            all_key +=1
                    if all_key == len(list):
                        return False
                def br_tree(br, list):
                    # Extract one index path from br into `list`, popping the
                    # consumed entries from br.
                    for i in range(len(br)):
                        if type(br[i]) == type([]) and len(br[i]) != 1:
                            br_tree(br[i], list)
                            if empty_list(br[i]):
                                list.insert(0, br[i][br[i].index([])-1])
                                br[i].pop(br[i].index([]))
                                list.insert(0, br[i - 1])
                            else:
                                list.insert(0,br[i-1])
                            break
                        elif type(br[i]) == type([]) and len(br[i]) == 1:
                            num = br[i].pop(0)
                            list.append(num)
                            break
                        elif type(br[i]) == type([]) and len(br[i]) == 0:
                            br.pop(br.index([]))
                            br.pop(-1)
                        elif all_num(br):
                            num = br.pop(0)
                            list.append(num)
                            break
                        elif type(br[0]) == type(1) and type(br[1]) == type([]) and br[1] == []:
                            br.pop(-1)
                            break
                while (br != []):
                    list = []
                    br_tree(br, list)
                    if len(list) == 0:
                        list.append(br[0])
                        br.pop(-1)
                    elif len(list) == 1:
                        list.insert(0,br[0])
                    all.append(list)
                return all
            def max_lable(dataset):
                # Return the label carried by the most rows (the first such
                # label on ties). A string argument is returned unchanged;
                # other types return None.
                if type(dataset) == type([]):
                    num = []
                    for i in range(len(self.label)):
                        num.append(0)
                    for i in range(len(dataset)):
                        for j in range(len(self.label)):
                            if dataset[i][-1] == self.label[j]:
                                num[j] += 1
                    max = 0
                    for i in range(len(num)):
                        if num[i] > max:
                            max = num[i]
                    return self.label[num.index(max)]
                elif type(dataset) == type('abc'):
                    return dataset
            def tree_lable(tree_data):
                # Replace every leaf row list in tree_data with its majority
                # label, recursing into subtrees; returns tree_data.
                for i in range(len(tree_data)):
                    if i != 0 and len(tree_data[i]) ==3 and len(tree_data[i][0]) == 2:
                        tree_data[i]=tree_lable(tree_data[i])
                    elif i != 0 and len(tree_data[i]) !=3 :
                        tree_data[i] = max_lable(tree_data[i])
                    elif i !=0 and  len(tree_data[i]) ==3 and  len(tree_data[i][0]) != 2:
                        tree_data[i] = max_lable(tree_data[i])
                return tree_data
            br = branchs(tree)
            first = tree
            first_data = tree_data
            br = process(br)

            def tree_list(tree_data,list):
                # Prune tree_data in place at the index path `list`: the node at
                # that path is replaced by the flat list of all rows beneath it.
                # An empty path leaves tree_data unchanged.
                def the_Data(tree):
                    # All leaf rows beneath `tree`, concatenated.
                    the_da = []
                    for i in range(len(tree)):
                        if len(tree[i]) == 3 and len(tree[i][0]) == 2:
                            the_da.extend(the_Data(tree[i]))
                        elif i != 0 and len(tree[i]) !=3 :
                            the_da.extend(tree[i])
                        elif i != 0 and len(tree[i]) == 3 and type(tree[i][0][-1]) == type('str'):
                            the_da.extend(tree[i])
                    return the_da
                def dir(tree_data,list):
                    if len(list) == 0 :
                        tree_data = the_Data(tree_data)
                        return tree_data
                    else:
                        tree_data[list[0]] = dir(tree_data[list[0]],list[1:])
                        return tree_data
                dir(tree_data,list)
                return  tree_data

            # Try each candidate pruning path in turn. Keep the pruned tree
            # whenever its test accuracy is at least the current tree's.
            for i in range(len(br)):
                second_data = copy.deepcopy(first_data)
                tree_list(second_data,br[i])
                data = copy.deepcopy(second_data)
                second = tree_lable(data)
                firs_acc = test_tree(first,test_dataset)
                sec_acc = test_tree(second,test_dataset)
                # print(firs_acc)
                # print(sec_acc)
                if firs_acc <= sec_acc:
                    first = second
                    first_data = second_data
            return first

        tree = post_purning(tree,tree_data,test_dataset)
        return tree

    def label_sample(self,train_tree,test_data):
        def process(train_tree,test_data):
            label = None
            if test_data[train_tree[0][0]] < train_tree[0][1]:
                if train_tree[1] not in self.label:
                    train_tree = train_tree[1]
                    label = process(train_tree,test_data)
                else:
                    label = train_tree[1]
            else:
                if train_tree[2] not in self.label:
                    train_tree = train_tree[2]
                    label =process(train_tree,test_data)
                else:
                    label = train_tree[2]
            return label
        label = process(train_tree,test_data)
        # print(test_data[:-1],'\'s true label:',test_data[-1],' predict label is :',label)
        return label

    def k_accuracy(self,data):
        def accuracy(train,test):
            acc = 0
            my_tree = self.train_tree(train)
            for i in range(len(test)):
                if test[i][-1] == self.label_sample(my_tree,test[i]):
                    acc += 1
            return acc/len(test)
        for i in range(len(data)):
            print('No-pruning Processing',i+1,'batch.....')
            print('The accuracy is',accuracy(data[i][0],data[i][1]),'.....')
            print('Batch',i+1,'is finished.....')
            print('The purning is finished..')
        print('*****************************************')
    def pre_accuracy(self,data):
        def accuracy(train,test):
            acc = 0
            my_tree = self.pre_pruning(train,test)
            for i in range(len(test)):
                if test[i][-1] == self.label_sample(my_tree,test[i]):
                    acc += 1
            return acc/len(test)
        for i in range(len(data)):
            print('Pre-pruning Processing',i+1,'batch.....')
            print('The accuracy is',accuracy(data[i][0],data[i][1]),'.....')
            print('Batch',i+1,'is finished.....')
            print('The purning is finished..')
        print('*****************************************')
    def post_accuracy(self,data):
        def accuracy(train,test):
            acc = 0
            my_tree = self.post_pruning(train,test)
            for i in range(len(test)):
                if test[i][-1] == self.label_sample(my_tree,test[i]):
                    acc += 1
            return acc/len(test)
        for i in range(len(data)):
            print('Post-pruning Processing',i+1,'batch.....')
            print('The accuracy is',accuracy(data[i][0],data[i][1]),'.....')
            print('Batch',i+1,'is finished.....')
            print('The purning is finished..')
        print('*****************************************')
if __name__ == '__main__':
    # Load the comma-separated Iris file (class name in the last column),
    # then compare the unpruned, pre-pruned and post-pruned trees on the
    # folds returned by tree.k_data(5).
    label = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    tree = decisionTree('iris.data', label)
    k_data = tree.k_data(5)
    for evaluate in (tree.k_accuracy, tree.pre_accuracy, tree.post_accuracy):
        evaluate(k_data)


以下是数据集

鸢尾花卉Iris数据集描述:

Iris 是鸢尾属植物，数据集记录了其萼片和花瓣的长度与宽度，共 4 个属性；鸢尾植物分为三类。因此该数据集一共包含 4 个特征变量和 1 个类别变量，共有 150 个样本。三个类别对应鸢尾属的三个种，分别是山鸢尾（Iris-setosa）、变色鸢尾（Iris-versicolor）和维吉尼亚鸢尾（Iris-virginica）。

也就是说，数据集中每个样本含有四个属性，我们的任务是一个三分类问题。三个类别分别为：Iris Setosa（山鸢尾）、Iris Versicolour（变色鸢尾）和 Iris Virginica（维吉尼亚鸢尾）。

例如:

样本一:5.1, 3.5, 1.4, 0.2, Iris-setosa

其中“5.1,3.5,1.4,0.2”依次是当前样本四个属性（萼片长度、萼片宽度、花瓣长度、花瓣宽度，单位为厘米）的取值，“Iris-setosa”是当前样本的类别。


5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica




  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值