# ① readdata(): 读取数据
# ② TreeNode: 构建树, 但后面都用字典进行构建了
# ③ binary_split(): 按照最优切分值, 把 data 进行切分
# ④ choosebestsplit(): 选择最好的特征进行拆分, 遍历每一个 feature 中每一个切分值, 找到 loss 最小的
# ⑤ creattree(): 构建树, 按照 bestfeature, bestval 进行递归构建树, 调用了 ③ ④ 两个功能
# ⑥ istree(): 判断是否是字典类型, 从而判断是否是一棵树
# ⑦ getmean(): 对于切分后没有数据, 只剩下树的, 进行合并处理, 用该树的平均值作为输出
# ⑧ prune(): 进行剪枝处理
#    判断1: 如果有一棵是树, 对数据进行切分
#    判断2: 如果是树, 递归进行剪枝, 判断是否需要合并
#    判断3: 如果都是节点, 不是树了, 判断是否需要合并
#    如果都是树, return
import numpy as np
def readdata(filename):
    """Load a tab-separated numeric data file into a float32 numpy array.

    :param filename: path to a text file, one sample per line, fields separated by tabs
    :return: np.ndarray of shape (n_lines, n_fields), dtype float32
    """
    mat = []
    # 'with' guarantees the handle is closed even on error (original leaked it);
    # iterating the file directly avoids materializing readlines()
    with open(filename, 'r') as f:
        for line in f:
            mat.append(line.strip().split('\t'))
    return np.array(mat, dtype=np.float32)
class TreeNode:
    """Binary tree node for a regression tree.

    NOTE: unused below — the tree is actually built with plain dicts
    (keys 'feature', 'val', 'left', 'right'); kept for reference.
    """
    def __init__(self, feature, left, right, val):
        # split feature index and threshold value
        self.feature = feature
        self.val = val
        # child subtrees (or leaf values)
        self.left = left
        self.right = right
def binary_split(dataset, feature_index, value):
    """Partition dataset rows on column feature_index at threshold value.

    :return: (rows with feature <= value, rows with feature > value)
    """
    mask = dataset[:, feature_index] <= value
    return dataset[mask], dataset[~mask]
def cal_mean(dataset):
    """Leaf model for the regression tree: mean of the target (last) column."""
    return dataset[:, -1].mean()
def cal_loss(dataset):
    """Split criterion: variance of the target (last) column.

    Variance is the mean squared deviation from the region's own mean,
    so minimizing it picks the most homogeneous split.
    """
    return dataset[:, -1].var()
def choosebestsplit(dataset, regLeaf=cal_mean, regerr=cal_loss, ops=(1, 4)):
    """Find the (feature, value) split minimizing the summed per-side loss.

    :param dataset: numpy array, last column is the target
    :param regLeaf: leaf-value model (mean for a regression tree)
    :param regerr: loss function for a region
    :param ops: (tol_s, min_n) — stop if the loss reduction is below tol_s,
                and reject splits leaving fewer than min_n samples on a side
    :return: (feature_index, split_value), or (None, leaf_value) when no
             worthwhile split exists
    """
    tol_s, min_n = ops
    n = dataset.shape[1]
    # bug fix: baseline loss must use the pluggable regerr, not hard-coded cal_loss
    S = regerr(dataset)
    best_s = np.inf
    bestfeature = 0
    best_val = 0
    for featindex in range(n - 1):  # last column is the target, skip it
        for val in set(dataset[:, featindex]):
            mat0, mat1 = binary_split(dataset, featindex, val)
            # reject splits leaving too few samples on either side
            if mat0.shape[0] < min_n or mat1.shape[0] < min_n:
                continue
            temp_err = regerr(mat0) + regerr(mat1)
            if temp_err < best_s:
                best_s = temp_err
                bestfeature = featindex
                best_val = val
    # covers both "improvement too small" and "no valid split found"
    # (best_s stays inf in the latter case, so S - best_s = -inf < tol_s)
    if S - best_s < tol_s:
        return None, regLeaf(dataset)
    return bestfeature, best_val
def creattree(dataset, regLeaf=cal_mean, regerr=cal_loss, ops=(1, 4)):
    '''
    Recursively build a dict-based regression tree.

    :param dataset: numpy array, last column is the target
    :param regLeaf: 用于计算输出,回归树中就是均值,或者线性函数也行
    :param regerr: 用于计算损失值,标准可以自己设定
    :param ops: # 第一个变量 1 用于控制 损失下降,下降 < 1,return, 第二个用于控制节点数目,节点太少,return
    :return: 一棵树 — a dict {'feature', 'val', 'left', 'right'} or a leaf value
    '''
    # bug fix: forward the regLeaf/regerr parameters instead of hard-coding
    # cal_mean/cal_loss, so a custom leaf model or loss actually takes effect
    bestfeature, best_val = choosebestsplit(dataset, regLeaf, regerr, ops)
    if bestfeature is None:
        return best_val  # no worthwhile split: this node is a leaf value
    regtree = {}
    regtree['feature'] = bestfeature
    regtree['val'] = best_val
    mat0, mat1 = binary_split(dataset, bestfeature, best_val)
    regtree['left'] = creattree(mat0, regLeaf, regerr, ops)
    regtree['right'] = creattree(mat1, regLeaf, regerr, ops)
    return regtree
def istree(obj):
    """Return True if obj is an (internal) dict tree node; leaves are plain numbers."""
    # idiom fix: isinstance instead of comparing the type name string
    return isinstance(obj, dict)
def getmean(tree):
    '''
    Collapse a subtree into a single leaf value: the recursive mean of its leaves.

    Bug fix: the original tested `istree(obj)` where `obj` is undefined, raising
    NameError the first time pruning hit an empty test split. The dict check is
    done inline so the function is self-contained.

    :param tree: dict tree node {'feature', 'val', 'left', 'right'}
    :return: 输出左右节点的平均值 (mean of the two collapsed children)
    '''
    if isinstance(tree['right'], dict):
        tree['right'] = getmean(tree['right'])
    if isinstance(tree['left'], dict):
        tree['left'] = getmean(tree['left'])
    return (tree['right'] + tree['left']) / 2
def prune(tree, testdata):
    """Post-prune a dict-based regression tree against held-out test data.

    Recursively prunes both subtrees first; then, once both children are leaf
    values, merges them into their mean if that lowers squared error on testdata.

    :param tree: dict tree from creattree (keys 'feature', 'val', 'left', 'right')
    :param testdata: numpy array with the same column layout as the training data
    :return: the pruned tree (dict) or a merged leaf value
    """
    if testdata.shape[0] == 0:
        return getmean(tree)  # 没数据了,直接合并 — no test samples reach this subtree
    if istree(tree['left']) or istree(tree['right']):  # at least one child is a subtree
        left, right = binary_split(testdata, tree['feature'], tree['val'])
        if istree(tree['left']): tree['left'] = prune(tree['left'], left)
        if istree(tree['right']): tree['right'] = prune(tree['right'], right)
    if not istree(tree['left']) and not istree(tree['right']):
        # compare test-set squared error with vs. without merging the two leaves
        left, right = binary_split(testdata, tree['feature'], tree['val'])
        nomerge_loss = np.sum((left[:, -1] - tree['left'])**2) + np.sum((right[:, -1] - tree['right'])**2)
        merge_mean = (tree['left'] + tree['right']) / 2
        merge_loss = np.sum((testdata[:, -1] - merge_mean)**2)
        # bug fix: original compared nomerge_loss against merge_mean (a value,
        # not a loss); the decision must compare the two losses
        if nomerge_loss > merge_loss:
            print('merge')
            return merge_mean
        else:
            return tree
    else:  # 如果有一棵是树,一个是值 — deeper merges already handled recursively above
        return tree
if __name__ == '__main__':
    # Demo: grow a maximal tree (ops=(0, 1): no loss threshold, 1-sample leaves
    # allowed) on ex2.txt, then prune it against the held-out ex2test.txt.
    # Guarded so importing this module no longer triggers file I/O.
    ops = (0, 1)
    data = readdata(r'ex2.txt')
    whole_tree = creattree(data, cal_mean, cal_loss, ops)
    print(whole_tree)
    test = readdata(r'ex2test.txt')
    print(prune(whole_tree, testdata=test))