数据结构系列:线段树 python版

前言

本篇归纳线段树

线段树

几篇博文归纳的很好

模板

class SegmentTree:
    def __init__(self, data, merge): 
        '''
        data:传入的数组
        merge:处理的业务逻辑,例如求和/最大值/最小值,lambda表达式
        '''

        self.data = data
        self.n = len(data)
        #  申请4倍data长度的空间来存线段树节点
        self.tree = [None] * (4 * self.n) # 索引i的左孩子索引为2i+1,右孩子为2i+2
        self._merge = merge
        if self.n:
            self._build(0, 0, self.n-1)


    def query(self, ql, qr):
        '''
        返回区间[ql,..,qr]的值
        '''
        return self._query(0, 0, self.n-1, ql, qr)

    def update(self, index, value):
        # 将data数组index位置的值更新为value,然后递归更新线段树中被影响的各节点的值
        self.data[index] = value
        self._update(0, 0, self.n-1, index)

    def _build(self, tree_index, l, r):
        '''
        递归创建线段树
        tree_index : 线段树节点在数组中位置
        l, r : 该节点表示的区间的左,右边界
        '''
        if l == r:
            self.tree[tree_index] = self.data[l]
            return
        mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
        left, right = 2 * tree_index + 1, 2 * tree_index + 2 # tree_index的左右子树索引
        self._build(left, l, mid)
        self._build(right, mid+1, r)
        self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])

    def _query(self, tree_index, l, r, ql, qr):
        '''
        递归查询区间[ql,..,qr]的值
        tree_index : 某个根节点的索引
        l, r : 该节点表示的区间的左右边界
        ql, qr: 待查询区间的左右边界
        '''
        if l == ql and r == qr:
            return self.tree[tree_index]

        mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
        left, right = tree_index * 2 + 1, tree_index * 2 + 2
        if qr <= mid:
            # 查询区间全在左子树
            return self._query(left, l, mid, ql, qr)
        elif ql > mid:
            # 查询区间全在右子树
            return self._query(right, mid+1, r, ql, qr)

        # 查询区间一部分在左子树一部分在右子树
        return self._merge(self._query(left, l, mid, ql, mid), 
                          self._query(right, mid+1, r, mid+1, qr))

    def _update(self, tree_index, l, r, index):
        '''
        tree_index:某个根节点索引
        l, r : 此根节点代表区间的左右边界
        index : 更新的值的索引
        '''
        if l == r == index:
            self.tree[tree_index] = self.data[index]
            return
        mid = (l+r)//2
        left, right = 2 * tree_index + 1, 2 * tree_index + 2
        if index > mid:
            # 要更新的区间在右子树
            self._update(right, mid+1, r, index)
        else:
            # 要更新的区间在左子树index<=mid
            self._update(left, l, mid, index)
        # 里面的小区间变化了,包裹的大区间也要更新
        self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])

例子

区域和检索 - 数组可修改

leet上307题

给你一个数组 nums ,请你完成两类查询,其中一类查询要求更新数组下标对应的值,另一类查询要求返回数组中某个范围内元素的总和。
实现 NumArray 类:
NumArray(int[] nums) 用整数数组 nums 初始化对象
void update(int index, int val) 将 nums[index] 的值更新为 val
int sumRange(int left, int right) 返回子数组 nums[left, right] 的总和(即,nums[left] + nums[left + 1], ..., nums[right]
class NumArray:

    def __init__(self, nums: List[int]):
        self.segment_tree = SegmentTree(nums, lambda x, y : x + y)
        

    def update(self, i: int, val: int) -> None:
        self.segment_tree.update(i, val)
        

    def sumRange(self, i: int, j: int) -> int:
        return self.segment_tree.query(i, j)

class SegmentTree:
    def __init__(self, data, merge): 
        '''
        data:传入的数组
        merge:处理的业务逻辑,例如求和/最大值/最小值,lambda表达式
        '''

        self.data = data
        self.n = len(data)
        #  申请4倍data长度的空间来存线段树节点
        self.tree = [None] * (4 * self.n) # 索引i的左孩子索引为2i+1,右孩子为2i+2
        self._merge = merge
        if self.n:
            self._build(0, 0, self.n-1)


    def query(self, ql, qr):
        '''
        返回区间[ql,..,qr]的值
        '''
        return self._query(0, 0, self.n-1, ql, qr)

    def update(self, index, value):
        # 将data数组index位置的值更新为value,然后递归更新线段树中被影响的各节点的值
        self.data[index] = value
        self._update(0, 0, self.n-1, index)

    def _build(self, tree_index, l, r):
        '''
        递归创建线段树
        tree_index : 线段树节点在数组中位置
        l, r : 该节点表示的区间的左,右边界
        '''
        if l == r:
            self.tree[tree_index] = self.data[l]
            return
        mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
        left, right = 2 * tree_index + 1, 2 * tree_index + 2 # tree_index的左右子树索引
        self._build(left, l, mid)
        self._build(right, mid+1, r)
        self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])

    def _query(self, tree_index, l, r, ql, qr):
        '''
        递归查询区间[ql,..,qr]的值
        tree_index : 某个根节点的索引
        l, r : 该节点表示的区间的左右边界
        ql, qr: 待查询区间的左右边界
        '''
        if l == ql and r == qr:
            return self.tree[tree_index]

        mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
        left, right = tree_index * 2 + 1, tree_index * 2 + 2
        if qr <= mid:
            # 查询区间全在左子树
            return self._query(left, l, mid, ql, qr)
        elif ql > mid:
            # 查询区间全在右子树
            return self._query(right, mid+1, r, ql, qr)

        # 查询区间一部分在左子树一部分在右子树
        return self._merge(self._query(left, l, mid, ql, mid), 
                          self._query(right, mid+1, r, mid+1, qr))

    def _update(self, tree_index, l, r, index):
        '''
        tree_index:某个根节点索引
        l, r : 此根节点代表区间的左右边界
        index : 更新的值的索引
        '''
        if l == r == index:
            self.tree[tree_index] = self.data[index]
            return
        mid = (l+r)//2
        left, right = 2 * tree_index + 1, 2 * tree_index + 2
        if index > mid:
            # 要更新的区间在右子树
            self._update(right, mid+1, r, index)
        else:
            # 要更新的区间在左子树index<=mid
            self._update(left, l, mid, index)
        # 里面的小区间变化了,包裹的大区间也要更新
        self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])

结语

线段树作为一种高级数据结构
现在简单记一笔,回头再看看

算法有兴趣的可以来看看 在自然数,且所有的数不大于30000的范围内讨论一个问题:现在已知n条线段,把端点依次输入告诉你,然后有m个询问,每个询问输入一个点,要求这个点在多少条线段上出现过; 最基本的解法当然就是读一个点,就把所有线段比一下,看看在不在线段中; 每次询问都要把n条线段查一次,那么m次询问,就要运算m*n次,复杂度就是O(m*n) 这道题m和n都是30000,那么计算量达到了10^9;而计算机1秒的计算量大约是10^8的数量级,所以这种方法无论怎么优化都是超时 因为n条线段是固定的,所以某种程度上说每次都把n条线段查一遍有大量的重复和浪费; 线段树就是可以解决这类问题的数据结构 举例说明:已知线段[2,5] [4,6] [0,7];求点2,4,7分别出现了多少次 在[0,7]区间上建立一棵满二叉树:(为了和已知线段区别,用【】表示线段树中的线段) 【0,7】 / \ 【0,3】 【4,7】 / \ / \ 【0,1】 【2,3】 【4,5】 【6,7】 / \ / \ / \ / \ 【0,0】 【1,1】 【2,2】 【3,3】 【4,4】 【5,5】 【6,6】 【7,7】 每个节点用结构体: struct line { int left,right; // 左端点、右端点 int n; // 记录这条线段出现了多少次,默认为0 }a[16]; 和堆类似,满二叉树的性质决定a[i]的左儿子是a[2*i]、右儿子是a[2*i+1]; 然后对于已知的线段依次进行插入操作: 从树根开始调用递归函数insert // 要插入的线段的左端点和右端点、以及当前线段树中的某条线段 void insert(int s,int t,int step)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值