机器学习实战 CART部分,后剪枝代码

① readdata(): 读取数据
② TreeNode: 构建树, 但后面都用字典进行构建了
③ binary_split(): 按照最优切分值,把data进行切分
④ choosebestsplit(); 选择最好的特征进行拆分 遍历每一个feature中每一个切分值,找到loss最小的
⑤ creattree: 构建树,按照bestfeature, bestval,进行递归构建树, 调用了 ③ ④两个功能

⑥istree: 判断是否是字典类型, 从而判断是否是一棵树
⑦ getmean: 对于切分后没有数据,只剩下树的,进行合并处理,用该树的平均值作为输出
⑧ prune: 进行剪枝处理,
判断1: 如果有一棵是树,对数据进行切分
判断2: 如果是树,递归进行剪枝,判断是否需要合并
判断3: 如果都是节点,不是树了,判断是否需要合并
如果都是树,return

import numpy as np


def readdata(filename):
    f = open(filename, 'r')
    mat = []
    for line in f.readlines():
        data = line.strip().split('\t')
        mat.append(data)
    return np.array(mat, dtype=np.float32)


class TreeNode:
    def __init__(self, feature, left, right, val):
        self.feature = feature
        self.left = left
        self.right = right
        self.val = val


def binary_split(dataset, feature_index, value): # 根据第index个特征的val值对dataset进行划分
    mat0 = dataset[dataset[:, feature_index] <= value]
    mat1 = dataset[dataset[:, feature_index] > value]
    return mat0, mat1


def cal_mean(dataset):
    return np.mean(dataset[:, -1])


def cal_loss(dataset):
    return np.var(dataset[:, -1])  # 先对y求均值,然后求与y均方差,相当于就是直接求均方差


def choosebestsplit(dataset, regLeaf=cal_mean, regerr=cal_loss, ops=(1, 4)):
    m, n = len(dataset), len(dataset[0])
    tops, topN = ops[0], ops[1]
    S = cal_loss(dataset)
    best_s = np.inf
    bestfeature = 0
    best_val = 0
    for featindex in range(n-1):
        for val in set(dataset[:, featindex]):
            mat0, mat1 = binary_split(dataset, featindex, val)
            if mat0.shape[0] < topN or mat1.shape[0] < topN:
                continue # 是当前区域rm的均值,不能理解是原来总体的数据
            temp_err = regerr(mat0) + regerr(mat1)
            if temp_err < best_s:
                best_s = temp_err
                bestfeature = featindex
                best_val = val
    if S - best_s < tops:
        return None, regLeaf(dataset)
    # 不可能存在 mat0 和 mat1< topN的情况吧,毕竟不符合都跳过了
    return bestfeature, best_val

def creattree(dataset, regLeaf=cal_mean, regerr=cal_loss, ops=(1, 4)):
    '''
    :param regLeaf:  用于计算输出,回归树中就是均值,或者线性函数也行
    :param regerr:   用于计算损失值,标准可以自己设定
    :param ops:    #  第一个变量 1 用于控制 损失下降,下降 < 1,return, 第二个用于控制节点数目,节点太少,return
    :return:   一棵树
    '''
    bestfeature, best_val = choosebestsplit(dataset, cal_mean, cal_loss, ops)
    if bestfeature is None:
        return best_val
    regtree = {}
    regtree['feature'] = bestfeature
    regtree['val'] = best_val
    mat0, mat1 = binary_split(dataset, bestfeature, best_val)
    regtree['left'] = creattree(mat0, cal_mean, cal_loss, ops)
    regtree['right'] = creattree(mat1, cal_mean, cal_loss, ops)
    return regtree


def istree(obj):
    return (type(obj).__name__ == 'dict')


def getmean(tree):
    '''
    :param tree: 传进来只会是树,prune中判断完是树才执行, if 传入一棵树,递归到它的子节点,进行平均值递归输出
    :return:     输出左右节点的平均值
    '''
    if istree(obj): tree['right'] = getmean(tree['right'])
    if istree(obj): tree['left'] = getmean(tree['left'])
    return (tree['right'] +tree['left']) / 2


def prune(tree, testdata):
    if testdata.shape[0] == 0: return getmean(tree)   # 没数据了,直接合并
    if istree(tree['left']) or istree(tree['right']):  # 如果两个都是值呢,就不切分了
        left, right = binary_split(testdata, tree['feature'], tree['val'])
    if istree(tree['left']): tree['left'] = prune(tree['left'], left)
    if istree(tree['right']): tree['right'] = prune(tree['right'], right)
    if not istree(tree['left']) and not istree(tree['right']):
        # 损失函数的计算,根据测试集,计算合并前后,输出与测试集输出的均方差
        left, right = binary_split(testdata, tree['feature'], tree['val'])
        nomerge_loss = np.sum((left[:, -1] - tree['left'])**2) + np.sum((right[:, -1] - tree['right'])**2)
        merge_mean = (tree['left'] + tree['right']) / 2
        merge_loss = np.sum((testdata[:, -1] - merge_mean)**2)
        if nomerge_loss > merge_mean:   # 损失大了,进行合并
            print('merge')
            return merge_mean
        else:
            return tree
    else: # 如果有一棵是树,一个是值,递归下去,直到都是值,后面会进行回溯合并的
        return tree


ops = (0, 1)
data = readdata(r'ex2.txt')
whole_tree = creattree(data, cal_mean, cal_loss, ops)
print(whole_tree)
test = readdata(r'ex2test.txt')
print(prune(whole_tree, testdata=test))

```python
在这里插入代码片

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值