[python刷题模板] 线段树

一、 算法&数据结构

1. 描述

线段树通常用来维护区间询问,通过二分的形式对数组分块,维护一个个小块上的属性(极值、求和等),用以实现O(logn)时间的查询和更新

2. 复杂度分析

  1. 查询query, O(log2n)
  2. 更新update,O(log2n)

3. 常见应用

  1. 单点更新,区间求极值(最入门)
  2. 单点更新,区间求和(稍复杂)
  3. 区间更新,单点或区间求值,如果卡常数需要用到lazytag

4. 常用优化

  1. 设置lazytag,用于区间更新,判断全包含时,不再向下递归,一般卡常数可以搞,每次update和query都需要give_lazy_to_son
  2. 离散化,因为线段树维护的是整数,如果题目给的是实数(浮点、复数、过大的数),那么可以把数据离散化,毕竟数组长度一般不会太大。

二、 模板代码

1. 区间更新,区间询问最大值(IUIQ)

例题: 699. 掉落的方块
用线段树维护x轴上每个线段(区间)的最大高度

class IntervalTree:
    def __init__(self, size):
        self.size = size
        self.interval_tree = [0 for _ in range(size*4)]
        self.lazys = [0 for _ in range(size*4)]

    def give_lay_to_son(self,p,l,r):
        interval_tree = self.interval_tree
        lazys = self.lazys
        if lazys[p] == 0:
            return
        mid = (l+r)//2
        interval_tree[p*2] = lazys[p]
        interval_tree[p*2+1] = lazys[p]
        lazys[p*2] = lazys[p]
        lazys[p*2+1] = lazys[p]
        lazys[p] = 0
        
    def update(self,p,l,r,x,y,val):
        """
        把[x,y]区域全变成val
        """
        if y < l or r < x:
            return 
        interval_tree = self.interval_tree    
        lazys = self.lazys        
        if x <= l and r<=y:
            interval_tree[p] = val
            lazys[p] = val
            return
        self.give_lay_to_son(p,l,r)
        mid = (l+r)//2
        if x <= mid:
            self.update(p*2,l,mid,x,y,val)
        if mid < y:
            self.update(p*2+1,mid+1,r,x,y,val)
        interval_tree[p] = max(interval_tree[p*2], interval_tree[p*2+1])
    
    def query(self,p,l,r,x,y):
        """
        查找x,y区间的最大值        """        
        
        if y < l or r < x:
            return 0
        if x<=l and r<=y:
            return self.interval_tree[p]
        self.give_lay_to_son(p,l,r)
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s = max(s,self.query(p*2,l,mid,x,y))
        if mid < y:
            s = max(s,self.query(p*2+1,mid+1,r,x,y))
        return s

class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        n = len(positions)
        hashes = [left for left,_ in positions] + [left+side for left,side in positions] 
        hashes = sorted(list(set(hashes)))
        # 用线段树维护x轴区间最大值,记录每个点的高度:比如[1,2]这个方块,会使线段[1,2]闭区间这个线段上的每个高度都变成2
        # 落下一个新方块时,查询它的底边所在线段[x,y]的最大高度h,这个方块会落在这个高度h,把新高度h+side插入线段树[x,y]的部分
        # 每次插入结束,树根存的高度就是当前最大高度
        # 由于数据范围大1 <= lefti <= 108,需要散列化
        # 散列化的值有left和right(线段短点)
        # print(hashes)
        tree_size = len(hashes)
        tree = IntervalTree(tree_size)
        heights = []
        for left,d in positions:
            right = left + d 
            l = bisect_left(hashes,left)
            r = bisect_left(hashes,right)
            h = tree.query(1,1,tree_size,l+1,r)
            tree.update(1,1,tree_size,l+1,r,h+d)
            heights.append(tree.interval_tree[1])
        return heights

2. 矩形面积并

链接: 850. 矩形面积 II

线段树经典案例,涉及离散化扫描线,细节很多,非常难写

class IntervalTreeNode:
    def __init__(self, len,cover):
        self.len = len
        self.cover = cover
class IntervalTree:
    def __init__(self, size,ys=None):
        self.size = size
        # self.interval_tree = [IntervalTreeNode(0,0)]*(size*4)     ## 这个地方wa了很久,不能这么写,一直更新一个实例
        self.interval_tree = [IntervalTreeNode(0,0) for _ in range(size*4)]   
        self.ys=ys

    def update_from_son(self,p,l,r):
        interval_tree = self.interval_tree
        pn = interval_tree[p]
        if pn.cover > 0:
            pn.len = self.ys[r-1]-self.ys[l-1]
        else:
            if l+1 ==r:
                pn.len = 0
            else:
                pn.len = interval_tree[p*2].len + interval_tree[p*2+1].len 

    def insert(self,p,l,r,x,y,cover):        
        if y < l or r < x:
            return
        interval_tree = self.interval_tree        
        if x<=l and r<=y:
            interval_tree[p].cover += cover
            self.update_from_son(p,l,r)
            return
        mid = (l+r)//2
        if x < mid:
            self.insert(p*2,l,mid,x,y,cover)
        if y > mid:
            self.insert(p*2+1,mid,r,x,y,cover)
        self.update_from_son(p,l,r)
            
class LineY:
    def __init__(self,x,y1,y2,cover):
        self.x = x
        self.y1 = y1
        self.y2 = y2
        self.cover = cover
        
class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        lines = []  # 所有竖线线段
        ys = set()  # 离散化
        for x1,y1,x2,y2 in rectangles:
            lines.append(LineY(x1,y1,y2,1))
            lines.append(LineY(x2,y1,y2,-1))
            ys.add(y1)
            ys.add(y2)
        lines.sort(key=lambda x:x.x)
        line_count = len(lines)
        ys = list(ys)
        ys.sort()

        interval_tree = IntervalTree(line_count,ys=ys)

        ans = 0
        mod = int(1e9+7)
        for i in range(0,line_count):
            line = lines[i]
            # print(line.x,line.y1,line.y2,line.cover)
            y1 = bisect_left(ys,line.y1) 
            y2 = bisect_left(ys,line.y2)
            # if y1==y2:
            #     continue
            if i >0:
                ans += (line.x-lines[i-1].x) * interval_tree.interval_tree[1].len
                ans %= mod
            interval_tree.insert(1,1,len(ys),y1+1,y2+1,line.cover)
        
        return ans

3.单点更新,区间求和

链接: 307. 区域和检索 - 数组可修改

线段树经典案例,比区间求极值麻烦一点点

class IntervalTree:
    def __init__(self, size,nums=None):
        self.size = size
        self.nums = nums
        self.interval_tree = [0]*(size*4)
        if nums:
            self.build_tree(1,1,size)

    def build_tree(self,p,l,r):
        interval_tree = self.interval_tree
        nums = self.nums
        if l == r:
            interval_tree[p] = nums[l-1]
            return
        mid = (l+r)//2
        self.build_tree(p*2,l,mid)
        self.build_tree(p*2+1,mid+1,r)
        interval_tree[p] = interval_tree[p*2]+interval_tree[p*2+1]
    
    def add_point(self,p,l,r,index,add):        
        if index < l or r < index:
            return 
        interval_tree = self.interval_tree
        interval_tree[p] += add
        if l == r:
            return
        mid = (l+r)//2
        if index <= mid:
            self.add_point(p*2,l,mid,index,add)
        else:
            self.add_point(p*2+1,mid+1,r,index,add)
    
    def sum_interval(self,p,l,r,x,y):        
        if y < l or r < x:
            return 0
        interval_tree = self.interval_tree
        if x<=l and r<=y:
            return interval_tree[p]
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s += self.sum_interval(p*2,l,mid,x,y)
        if mid < y:
            s += self.sum_interval(p*2+1,mid+1,r,x,y)
        return s

class NumArray:

    def __init__(self, nums: List[int]):
        self.size = len(nums)
        self.nums = nums
        self.interval_tree = IntervalTree(self.size ,nums)
        
    def update(self, index: int, val: int) -> None:         
        add = val - self.nums[index]
        self.nums[index] = val
        self.interval_tree.add_point(1,1,self.size,index+1,add)
        
    def sumRange(self, left: int, right: int) -> int:        
        return self.interval_tree.sum_interval(1,1,self.size,left+1,right+1)         

4.单点更新,区间求和

链接: 327. 区间和的个数

线段树这题麻烦一点,求区间内数字数量,每个数初始化为0,插入时候+1,计数转化为求和
参考链接: [LeetCode解题报告]327. 区间和的个数

class IntervalTree:
    def __init__(self, size):
        self.size = size
        self.interval_tree = [0 for _ in range(size*4)]

    def insert(self,p,l,r,index):
        if index < l or r < index:
            return 
        interval_tree = self.interval_tree        
        if l == r:
            interval_tree[p] += 1
            return
        mid = (l+r)//2
        if index <= mid:
            self.insert(p*2,l,mid,index)
        else:
            self.insert(p*2+1,mid+1,r,index)
        interval_tree[p] = interval_tree[p*2]+interval_tree[p*2+1]       
    
    def query(self,p,l,r,x,y):
        if y < l or r < x:
            return 0
        if x<=l and r<=y:
            return self.interval_tree[p]
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s += self.query(p*2,l,mid,x,y)
        if mid < y:
            s += self.query(p*2+1,mid+1,r,x,y)
        return s

class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:         
        s = list(accumulate(nums,initial=0))
        hashes = s + [ x-lower for x in s] + [ x-upper for x in s]
        hashes = sorted(list(set(hashes)))
        # 生成前缀和,问题转化为,对于每个j,找左边的i,判断 s[j]-upper<=s[i]<=s[j]-lower,统计这些i的数量
        # 把所有前缀和数组中的数字插入线段树,并对这些数字划分区间,线段树维护当前区间数字数量,
        # 所以需要对这些数字都散列化
        tree_size = len(hashes)
        tree = IntervalTree(tree_size)
        cnt = 0
        for i in s:
            x = bisect_left(hashes,i-upper)
            y = bisect_left(hashes,i-lower)
            j = bisect_left(hashes,i)
            c = tree.query(1,1,tree_size, x+1,y+1)
            # print(x,y,j,c)
            cnt += c
            tree.insert(1,1,tree_size,j+1)

        return cnt

5.区间更新,区间查询,无法离散化,动态开点。

链接: 3732. 我的日程安排表 III

线段树这题由于只能在线做,不能做离散化,因此需要用字典维护线段端点,实现动态开点。
参考链接: [LeetCode解题报告] 732. 我的日程安排表 III

class IntervalTree:
    def __init__(self):
        self.interval_tree = collections.defaultdict(int)
        self.lazys = collections.defaultdict(int)        

    def give_lay_to_son(self,p,l,r):
        interval_tree = self.interval_tree
        lazys = self.lazys
        if lazys[p] == 0:
            return
        mid = (l+r)//2
        interval_tree[p*2] += lazys[p]
        interval_tree[p*2+1] += lazys[p]
        lazys[p*2] += lazys[p]
        lazys[p*2+1] += lazys[p]
        lazys[p] = 0
        
    def add(self,p,l,r,x,y,val):
        """
        把[x,y]区域全+val
        """
        if r < x or y < l:  # 这里不加就会TLE
            return 
        interval_tree = self.interval_tree    
        lazys = self.lazys        
        if x <= l and r<=y:
            interval_tree[p] += val
            lazys[p] += val
            return
        self.give_lay_to_son(p,l,r)  #这题由于永远不会询问子区间,所以其实可以不向下give,直接在return的时候+lazy,会快一点。
        mid = (l+r)//2
        if x <= mid:
            self.add(p*2,l,mid,x,y,val)
        if mid < y:
            self.add(p*2+1,mid+1,r,x,y,val)
        interval_tree[p] = max(interval_tree[p*2], interval_tree[p*2+1]) 
    
    def query(self,p,l,r,x,y):
        """
        查找x,y区间的最大值
        """        
        if x<=l and r<=y:
            return self.interval_tree[p]
        self.give_lay_to_son(p,l,r)
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s = max(s,self.query(p*2,l,mid,x,y))
        if mid < y:
            s = max(s,self.query(p*2+1,mid+1,r,x,y))
        return s

class MyCalendarThree:

    def __init__(self):
        self.tree = IntervalTree()
   
    def book(self, start: int, end: int) -> int:
        self.tree.add(1,1,10**9+1,start,end-1,1)
        return self.tree.interval_tree[1]

6.单点更新,区间查询最大值。

链接: 6206. 最长递增子序列 II

线段树这题算是打开了LIS一个新的优化思路,传统N方DP由于单次查询是On因此是N方,用线段树可以把单次查询降低到lg
参考链接: [LeetCode周赛复盘] 第 310 场周赛20220911

class IntervalTree:
    def __init__(self, size,nums=None):
        self.size = size
        self.nums = nums
        self.interval_tree = [0]*(size*4)

    def update_point(self,p,l,r,index,val):        
        if index < l or r < index:
            return 
        interval_tree = self.interval_tree
        interval_tree[p] =max(interval_tree[p],val)
        if l == r:
            return
        mid = (l+r)//2
        if index <= mid:
            self.update_point(p*2,l,mid,index,val)
        else:
            self.update_point(p*2+1,mid+1,r,index,val)
    
    def query(self,p,l,r,x,y):
        """
        查找x,y区间的最大值        """        
        
        if y < l or r < x:
            return 0
        if x<=l and r<=y:
            return self.interval_tree[p]
        
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s = max(s,self.query(p*2,l,mid,x,y))
        if mid < y:
            s = max(s,self.query(p*2+1,mid+1,r,x,y))
        return s

    
class Solution:
    def lengthOfLIS(self, nums: List[int], k: int) -> int:
        n = len(nums)
        mx = max(nums)
        tree = IntervalTree(mx)
        ans = 0
        for i in range(n):
            v = nums[i]
            l = max(0,v-k)
            r = max(0,v-1)
            ret = tree.query(1,1,mx,l,r)+1
            tree.update_point(1,1,mx,v,ret)
            ans = max(ans,ret)

        return ans

7. 区间01翻转(异或),区间查询1的个数。

链接: 6358. 更新数组后处理求和查询
链接: P3870 [TJOI2009] 开关

  • 注意处理时,lazy为1才需要向下处理。
  • lazy^=1;然后重新计算区间1的个数,其实就是取反:长度-原个数。
class IntervalTree:
    def __init__(self, size):
        self.size = size
        self.interval_tree = [0 for _ in range(size*4)]
        self.lazys = [0 for _ in range(size*4)]

    def give_lay_to_son(self,p,l,r):
        interval_tree = self.interval_tree
        lazys = self.lazys
        if lazys[p] == 0:
            return
        mid = (l+r)//2
        interval_tree[p*2] = mid - l + 1 -  interval_tree[p*2]
        interval_tree[p*2+1] = r - mid - interval_tree[p*2+1]
        lazys[p*2] ^= 1
        lazys[p*2+1] ^=1
        lazys[p] = 0
        
    def update(self,p,l,r,x,y,val):
        """
        把[x,y]区域全变成val
        """
        if y < l or r < x:
            return 
        interval_tree = self.interval_tree    
        lazys = self.lazys        
        if x <= l and r<=y:
            interval_tree[p] = r-l+1-interval_tree[p]
            lazys[p] ^= 1
            return
        self.give_lay_to_son(p,l,r)
        mid = (l+r)//2
        if x <= mid:
            self.update(p*2,l,mid,x,y,val)
        if mid < y:
            self.update(p*2+1,mid+1,r,x,y,val)
        interval_tree[p] = interval_tree[p*2]+ interval_tree[p*2+1]    

    
    def query(self,p,l,r,x,y):
        """
        区间求和      """        
        
        if y < l or r < x:
            return 0
        if x<=l and r<=y:
            return self.interval_tree[p]
        self.give_lay_to_son(p,l,r)
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s += self.query(p*2,l,mid,x,y)
        if mid < y:
            s += self.query(p*2+1,mid+1,r,x,y)
        return s
    
class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        n = len(nums1)
        s = sum(nums2)
        tree = IntervalTree(n)
        for i,v in enumerate(nums1,start=1):
            if v:
                tree.update(1,1,n,i,i,1)
        ans = []
        for op,l,r in queries:
            if op == 1:
                tree.update(1,1,n,l+1,r+1,1)
            elif op == 2:
                s += l*tree.query(1,1,n,1,n)
            else:
                ans.append(s)
        return ans

三、其他

  1. 如果还是卡常数,有些区间问题可以转化为树状数组,常数小,代码短,不过真的很难理解,还是线段树好写。遇到就套板吧。

四、更多例题

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
算法有兴趣的可以来看看 在自然数,且所有的数不大于30000的范围内讨论一个问题:现在已知n条线段,把端点依次输入告诉你,然后有m个询问,每个询问输入一个点,要求这个点在多少条线段上出现过; 最基本的解法当然就是读一个点,就把所有线段比一下,看看在不在线段中; 每次询问都要把n条线段查一次,那么m次询问,就要运算m*n次,复杂度就是O(m*n) 这道题m和n都是30000,那么计算量达到了10^9;而计算机1秒的计算量大约是10^8的数量级,所以这种方法无论怎么优化都是超时 因为n条线段是固定的,所以某种程度上说每次都把n条线段查一遍有大量的重复和浪费; 线段树就是可以解决这类问题的数据结构 举例说明:已知线段[2,5] [4,6] [0,7];求点2,4,7分别出现了多少次 在[0,7]区间上建立一棵满二叉树:(为了和已知线段区别,用【】表示线段树中的线段) 【0,7】 / \ 【0,3】 【4,7】 / \ / \ 【0,1】 【2,3】 【4,5】 【6,7】 / \ / \ / \ / \ 【0,0】 【1,1】 【2,2】 【3,3】 【4,4】 【5,5】 【6,6】 【7,7】 每个节点用结构体: struct line { int left,right; // 左端点、右端点 int n; // 记录这条线段出现了多少次,默认为0 }a[16]; 和堆类似,满二叉树的性质决定a[i]的左儿子是a[2*i]、右儿子是a[2*i+1]; 然后对于已知的线段依次进行插入操作: 从树根开始调用递归函数insert // 要插入的线段的左端点和右端点、以及当前线段树中的某条线段 void insert(int s,int t,int step)

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值