用 Python 构建 k-d Tree

此处不做详细讲解,只简单演示如何构造 k-d Tree。

选轴策略:

1. 依次按 X、Y、Z……K 轴轮流划分;

2. 按方差最大的轴划分,即数据最分散的轴;

3. 按数据极差(最大值与最小值之差)最大的轴划分(本文采用此策略)。

构造k-d Tree

def get_kdtree(tree_list):
    """Build a k-d tree from a list of points, splitting on the widest axis.

    At every node the axis whose coordinates have the largest range
    (max - min) is chosen; on a tie the lowest axis index wins (so x beats y
    for 2-D points). The points are ordered along that axis and the element
    at index ``len // 2`` becomes the node's value. Elements before it form
    the left subtree, elements after it the right subtree.

    Args:
        tree_list: Sequence (list or tuple) of points. Every point is an
            indexable sequence of the same length k >= 1. The caller's
            sequence is NOT modified.

    Returns:
        ``None`` for an empty input, otherwise a nested dict
        ``{"value": point, "left": subtree, "right": subtree}``.
    """
    if not tree_list:
        return None
    if len(tree_list) == 1:
        return {"value": tree_list[0], "left": None, "right": None}

    def spread(axis):
        # Range (max - min) of the coordinates along `axis`.
        values = [point[axis] for point in tree_list]
        return max(values) - min(values)

    # Split on the axis with the largest range; max() returns the first
    # maximal axis, so ties resolve to the lower index.
    axis = max(range(len(tree_list[0])), key=spread)

    # sorted() rather than list.sort() so the caller's list keeps its order.
    ordered = sorted(tree_list, key=lambda point: point[axis])
    mid = len(ordered) // 2
    return {
        "value": ordered[mid],
        "left": get_kdtree(ordered[:mid]),
        "right": get_kdtree(ordered[mid + 1:]),
    }
  
        

if __name__ == '__main__':
    # Demo: build the tree for six sample 2-D points and print the nested dict.
    demo_points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    print(get_kdtree(demo_points))

输出结果

{'value': (7, 2), 'left': {'value': (5, 4), 'left': {'value': (2, 3), 'left': None, 'right': None}, 'right': {'value': (4, 7), 'left': None, 'right': None}}, 'right': {'value': (9, 6), 'left': {'value': (8, 1), 'left': None, 'right': None}, 'right': None}}

对于符号回归问题,SR-Tree的代码实现如下:

```python
import numpy as np


class Node:
    """One node of an SR_Tree.

    Internal nodes hold a bounding box and their children; leaf nodes also
    keep the samples (``data``) and targets (``labels``) that reached them.
    """

    def __init__(self, data=None, bounds=None, children=None, isleaf=True,
                 labels=None):
        # `data or []` raises ValueError for multi-row numpy arrays
        # ("truth value of an array ... is ambiguous"), so test for None.
        self.data = data if data is not None else []
        self.labels = labels if labels is not None else []
        self.bounds = bounds  # (n_features, 2) array of per-feature [min, max]
        self.children = children if children is not None else []
        self.isleaf = isleaf


class SR_Tree:
    """Space-partitioning regression tree that predicts by leaf averaging.

    Each internal node splits its samples at the median of a randomly chosen
    feature. A query descends into the child whose bounding box is nearest
    and returns the mean label of the leaf it reaches.
    """

    def __init__(self, data, labels, thres=1, maxdepth=5):
        """Build the tree from ``data`` and ``labels``.

        Args:
            data: (n_samples, n_features) feature matrix.
            labels: (n_samples,) regression targets.
            thres: a node with at most this many samples becomes a leaf.
            maxdepth: nodes at this depth become leaves.

        Raises:
            ValueError: if ``data`` is empty or not 2-D, or if its length
                differs from ``labels``.
        """
        self.data = np.asarray(data)
        self.labels = np.asarray(labels)
        if self.data.ndim != 2 or len(self.data) == 0:
            raise ValueError("data must be a non-empty 2-D array")
        if len(self.data) != len(self.labels):
            raise ValueError("data and labels must have the same length")
        self.thres = thres
        self.maxdepth = maxdepth
        self.root = self.build_tree(self.data, self.labels, depth=0)

    def build_tree(self, data, labels, depth):
        """Recursively build the subtree for ``data``/``labels`` at ``depth``."""
        if len(data) <= self.thres or depth >= self.maxdepth:
            return self._make_leaf(data, labels)
        children = self.split_data(data, labels)
        # If every sample landed on one side (e.g. duplicate values on the
        # chosen axis), recursing makes no progress and an empty child has no
        # bounding box, so stop here with a leaf.
        if any(len(child_labels) == 0 for _, child_labels in children):
            return self._make_leaf(data, labels)
        node = Node(bounds=self.get_bounds(data), isleaf=False)
        for child_data, child_labels in children:
            node.children.append(
                self.build_tree(child_data, child_labels, depth + 1))
        return node

    def _make_leaf(self, data, labels):
        """Create a leaf that keeps its samples and their labels."""
        return Node(data, bounds=self.get_bounds(data), isleaf=True,
                    labels=labels)

    def split_data(self, data, labels):
        """Split samples at the median of one randomly chosen feature.

        Returns ``([left_data, left_labels], [right_data, right_labels])``;
        the left side holds the samples strictly below the median.
        """
        k = np.random.randint(data.shape[1])
        pivot = np.median(data[:, k])
        left = data[:, k] < pivot
        return [data[left], labels[left]], [data[~left], labels[~left]]

    def query(self, x):
        """Predict the target of point ``x`` as the mean label of one leaf.

        Descends greedily into the child whose bounding box is closest to
        ``x`` (ties go to the first child), so the result is approximate.
        """
        x = np.asarray(x)
        node = self.root
        while not node.isleaf:
            dists = [self.get_dist(x, child.bounds) for child in node.children]
            node = node.children[int(np.argmin(dists))]
        # Average the leaf's labels, not a feature column of node.data.
        return np.mean(node.labels)

    def get_dist(self, x, bounds):
        """Squared Euclidean distance from point(s) ``x`` to the box ``bounds``.

        Coordinates inside ``[min, max]`` contribute 0; for a 2-D ``x`` the
        distances of all rows are summed.
        """
        x = np.atleast_2d(x)
        below = np.clip(bounds[:, 0] - x, 0, None)
        above = np.clip(x - bounds[:, 1], 0, None)
        return float(np.sum(below ** 2 + above ** 2))

    def get_bounds(self, data):
        """Return an (n_features, 2) array of per-feature [min, max]."""
        return np.column_stack((data.min(axis=0), data.max(axis=0)))


def generate_data(n=1000):
    """Sample ``n`` noisy points of a quadratic; returns an (n, 3) array [x0, x1, y]."""
    x = np.random.uniform(-10, 10, (n, 2))
    y = (0.5 * x[:, 0] ** 2 - 0.3 * x[:, 1] ** 2 + 2 * x[:, 0] - 3 * x[:, 1]
         + 5 + np.random.normal(0, 0.5, n))
    return np.hstack((x, y.reshape(-1, 1)))


if __name__ == '__main__':
    data = generate_data()
    sr_tree = SR_Tree(data[:, :-1], data[:, -1], thres=50, maxdepth=10)
    x = np.array([-3, 5])
    y = sr_tree.query(x)
    # The noise-free value of the generating function at (-3, 5) is -19.0.
    print('Result:', y)
```

在这个代码实现中,我们将符号回归问题的数据和标签作为SR-Tree的输入,然后构建SR-Tree来进行查询。在SR-Tree构建过程中,我们每次随机选择一个轴,以该轴的中位数为界将数据集分成两个子集,然后递归地构建SR-Tree;若某次划分使一侧为空,则直接生成叶节点。在查询时,我们从根节点开始遍历树,根据查询点与每个子节点边界框之间的距离,移动到最近的子节点,直到达到叶节点为止,然后返回该叶节点中样本标签的均值作为预测结果。

在这个实现中,我们使用numpy完成数值计算和距离计算。为了生成测试数据,我们定义了一个简单的二次函数,并添加了随机噪声。最后,我们使用SR-Tree来进行查询,以测试其性能。需要注意的是,这种做法本质上是基于空间划分的非参数回归(以叶节点均值作为预测),并不会像符号回归那样给出显式的数学表达式。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值