Segment Tree

模板

懒标记/朴素线段树

refs:https://www.bilibili.com/video/BV1G34y1L7b3/?spm_id_from=333.999.0.0(董晓算法)

模板题:【模板】线段树 1 - 洛谷   and 【模板】树状数组 1 - 洛谷

LuoGu语言歧视还卡常,无语,两道题其实都是写对的,但是他硬卡到了70/100。T我三个点。

from typing import List

# 带懒标记
class node:
    def __init__(self,l:int,r:int,sum:int,add:int) -> None:
        self.l = l
        self.r = r
        self.sum = sum
        self.add = add

# 下标从1开始
class SegTree:

    def lc(self,p:int)->int:
        return p<<1
    
    def rc(self,p:int)->int:
        return (p<<1)|1

    # 懒标记向下传递
    def pushdown(self,curr:int):
        if self.tr[curr].add:
            self.tr[self.lc(curr)].sum += (self.tr[self.lc(curr)].r-self.tr[self.lc(curr)].l+1)*self.tr[curr].add
            self.tr[self.rc(curr)].sum += (self.tr[self.rc(curr)].r-self.tr[self.rc(curr)].l+1)*self.tr[curr].add
            self.tr[self.lc(curr)].add += self.tr[curr].add
            self.tr[self.rc(curr)].add += self.tr[curr].add
            self.tr[curr].add= 0

    # 懒标记回溯
    def pushup(self,curr:int):
        self.tr[curr].sum = self.tr[self.lc(curr)].sum + self.tr[self.rc(curr)].sum
    

    def __init__(self,w:List[int]) -> None:
        w = [0]+w
        n = len(w)

        self.tr = [node(-1,-1,-1,-1) for _ in range(4*n)]

        def build(curr:int,l:int,r:int):
            self.tr[curr] = node(l,r,w[l],0)
            if l==r:
                return
            mid = (l+r)>>1
            build(self.lc(curr),l,mid)
            build(self.rc(curr),mid+1,r)
            self.tr[curr].sum = self.tr[self.lc(curr)].sum + self.tr[self.rc(curr)].sum
        
        build(1,1,n-1)

    def update_pt(self,curr:int,x:int,diff:int):
        if self.tr[curr].l == self.tr[curr].r == x:
            self.tr[curr].sum += diff
            return
        
        mid = (self.tr[curr].l+self.tr[curr].r)>>1

        if x<=mid:
            self.update_pt(self.lc(curr),x,diff)
        else:
            self.update_pt(self.rc(curr),x,diff)
        self.pushup(curr)

    def update(self,curr:int,x:int,y:int,diff:int):
        if x<=self.tr[curr].l and self.tr[curr].r<=y:
            self.tr[curr].sum += (self.tr[curr].r-self.tr[curr].l+1)*diff
            self.tr[curr].add += diff
            return
        mid = (self.tr[curr].l + self.tr[curr].r)>>1
        
        self.pushdown(curr)

        if x<=mid:
            self.update(self.lc(curr),x,y,diff)
        if y>mid:
            self.update(self.rc(curr),x,y,diff)

        self.pushup(curr)
    
    def query(self,curr:int,x:int,y:int)->int:
        if x<=self.tr[curr].l and self.tr[curr].r<=y:
            return self.tr[curr].sum
        
        mid = (self.tr[curr].l+self.tr[curr].r)>>1
        self.pushdown(curr)
        res = 0
        if x<=mid:
            res += self.query(self.lc(curr),x,y)
        if y>mid:
            res += self.query(self.rc(curr),x,y)
        return res

下标

从1开始,接收一个从0开始的数组。然后在前面拼接一个dummy元素即可将下标从1开始。build时记得n-1。

区间/点修改

  1. 点修改:用不上懒标记
  2. 区间修改:用懒标记。但不用担心和点修改导致冲突。点修改本身不会产生任何非叶子的add增量,所以这两个修改方式和懒标记都是兼容的。

这里简单说下懒标记为什么回溯。主要我一开始以为:

    def update(self,curr:int,x:int,y:int,diff:int):
        if x<=self.tr[curr].l and self.tr[curr].r<=y:
            self.tr[curr].sum += (self.tr[curr].r-self.tr[curr].l+1)*diff
            self.tr[curr].add += diff
            return

这一段已经改过父节点了,为什么下面还pushup,这不多此一举吗?后来我发现这个覆盖不是对任何情况生效的。举个例子:

在上图中我们修改区间[4,9],显然一开始会从线段[1,10]分裂。而分裂时不走if。因此[1,10]这个点的sum如果不pushup就不会更新。所以还是要回溯的。

区间覆盖参数

query和update中都有待查询/修改区间的左右边界x,y。可能会疑惑为什么递归时传参一直写死x,y,而不是[x,mid]和[mid+1,y]。(比如查询时)

这个写法在左半部分是没问题的,即[x,mid],因为这个区间是肯定要查的。问题出在右半部分,举个例子,假设我们要查[5,9]:

  1. mid[1,10]=5,分裂为[5,5]和[6,9]
  2. mid[1,5]=3,这个时候[mid+1,y] = [4,5],给下一层传参的x=4,y=5,这样if检查发现[4,5]正好和[x,y]重叠,那么就会多查一个[4,4]的区间。

换句话说,如果我们的x过大,甚至比mid+1还大,这样写查询范围就会多查了。

那么接下来就可能会想了:如果写成

query(self.rc(curr),min(y,max(mid+1,x)),y)

行不行?

  1. 首先,max(mid+1,x)保证了左边界不会爆掉。
  2. 其次,min(max(mid+1,x),y)保证了左边界不会把右边界爆掉。

其实还是不行。从LuoGu上下了数据,发现这一组会WA:

8 10
640 591 141 307 942 58 775 133 
2 1 5
2 3 8
2 3 6
2 5 8
2 4 8
1 4 8 60
2 1 6
2 5 8
1 3 7 15
1 2 6 86

其中第一个操作(2,1,5)就错了,答案是2621,我的程序跑出来2679。正好多了第六个元素58。

  1. 因为根节点左子树跑满了,所以左子树肯定不用看了,对的。
  2. [1,8] mid4,y=5,右子树要跑,x=5,y=5。
  3. [5,8] mid6,y=5,左子树要跑,x=5,y=6

这下就多跑了个w[6]了。

因此我们只要改成:

        if x<=mid:
            res += self.query(self.lc(curr),x,max(x,min(y,mid)))
        if y>mid:
            res += self.query(self.rc(curr),min(y,max(mid+1,x)),y)

即可。

但这么写也太麻烦了,还不如一直维持区间覆盖的参数不变呢。也就是(x,y)。

动态开点线段树

有时候整个区间特别长,比如长度可能是1e9、1e10这种。然后并不是每个单位区间都有值(或者说被赋值/修改),这个时候再开静态的4N的数组可能会MLE。所以就有了动态开点线段树。

在OI-Wiki中,对动态开点线段树核心思想的描述是:

总之,动态开点线段树的核心思想就是:结点只有在有需要的时候才被创建

即pure demanding。

class Node:
    def __init__(self,left,right,l:int,r:int,sum:int,add:int) -> None:
        self.left = left
        self.right = right
        self.l = l
        self.r = r
        self.sum = sum
        self.add = add
    
    def __str__(self)->str:
        return f'range:[{self.l},{self.r}],sum:{self.sum}'
        

class DynamicSegTree:
    def __init__(self,lx:int,ry:int) -> None:
        self.root = Node(None,None,lx,ry,0,0)
    
    def pushdown(self,node:Node):
        mid = (node.l+node.r)>>1
        if node.left is None:
            node.left = Node(None,None,node.l,mid,0,0)
        if node.right is None:
            node.right = Node(None,None,mid+1,node.r,0,0)
        if node.add:
            node.left.sum += (node.left.r-node.left.l+1)*node.add
            node.right.sum += (node.right.r-node.right.l+1)*node.add
            node.left.add += node.add
            node.right.add += node.add
            node.add = 0
    
    def pushup(self,node:Node):
        node.sum = node.left.sum + node.right.sum
    
    def update(self,l:int,r:int,diff:int,node:Node):
        if l<=node.l and node.r<=r:
            node.sum += (node.r-node.l+1)*diff
            node.add += diff
            return
        
        self.pushdown(node)
        mid = (node.l+node.r)>>1
        if l<=mid:
            self.update(l,r,diff,node.left)
        if r>mid:
            self.update(l,r,diff,node.right)
        self.pushup(node)

    def query(self,l:int,r:int,node:Node)->int:
        if l<=node.l and node.r<=r:
            return node.sum
        
        self.pushdown(node)
        
        res = 0
        mid = (node.l+node.r)>>1
        if l<=mid:
            res += self.query(l,r,node.left)
        if r>mid:
            res += self.query(l,r,node.right)
        return res

新增左右子树的时机

在pushdown中,有这么一段:

    def pushdown(self,node:Node):
        mid = (node.l+node.r)>>1
        if node.left is None:
            node.left = Node(None,None,node.l,mid,0,0)
        if node.right is None:
            node.right = Node(None,None,mid+1,node.r,0,0)
        if node.add:
	        # ......

这里的问题是,为什么不把更新左右子树的时机再拖一下,放到if node.add里面去?

这是因为下面紧接着就有:

        mid = (node.l+node.r)>>1
        if l<=mid:
            self.update(l,r,diff,node.left)
        if r>mid:
            self.update(l,r,diff,node.right)

如果之前没有懒标记就不新增左右子树,这里就会访问到None。举个例子,假设当前节点[1,5],更新[2,7]:

  1. mid = 3,分裂为[2,3],[4,5]。且此时当前节点add=0
  2. 1≤3,进左子树更新,由于pushdown没更新,那么node.left=None。完蛋。

所以只要会分裂,无论进哪个子树,直接新增就行了。

1. LC 699 掉落的方块

这题1e8+1e6的数据范围,开4N就MLE了。所以板子用动态开点的线段树。

比较重要的是这个线段树板子怎么改。它的意思不是给区间里每个数上加上一个数,而是把区间里所有数改成一个数。所以这里应该是重新赋值,而不是累加。

    def update(self,l:int,r:int,diff:int,node:Node):
        if l<=node.l and node.r<=r:
            node.sum = diff
            node.add = diff
            return
            
    def pushdown(self,node:Node):
        mid = (node.l+node.r)>>1
        if node.left is None:
            node.left = Node(None,None,node.l,mid,0,0)
        if node.right is None:
            node.right = Node(None,None,mid+1,node.r,0,0)
        if node.add:
            node.left.sum = node.add
            node.right.sum = node.add
            node.left.add = node.right.add = node.add
            node.add = 0

对于回溯(pushup)我们可以发现,当一个新的方块落下时,他应该查找[l,r]中最大的高度并且累加上去。所以回溯时,我们取左右子树的较大值作为父节点的sum。

    def pushup(self,node:Node):
        node.sum = max(node.left.sum,node.right.sum)

 

以上图为例,假设方块是:

[2,1]

[2,9]

[1,8]

这里另外要提的一点是,因为贴边不能叠,所以我们直接对右边界减一。这样就不会重复累加贴边的情况了。例如第一个方块[x,y],那么[x:x+y-1],第二个方块[z,x-z],则[z:x-1],但实际上这俩是贴边的。

所以:

  1. 首先我们先修改[2,2]的值为1
  2. 然后查询[2,10]区间的最大值,发现是[2,2]的1,给他加上9变成10,更新[2,10]区间的所有元素为10。

这个时候[1,2]的最大值很显然为10(例如我们要在[1,2]上放一个方块,他就得从10开始算了),但[1,1]的值为0,[2,2]的值为10,因此取左右子树最大值,也即10。

具体实现就是套板子加上l+s-1的贴边处理:

from typing import List

class Node:
    def __init__(self,left,right,l:int,r:int,sum:int,add:int) -> None:
        self.left = left
        self.right = right
        self.l = l
        self.r = r
        self.sum = sum
        self.add = add
        

class DynamicSegTree:
    def __init__(self,lx:int,ry:int) -> None:
        self.root = Node(None,None,lx,ry,0,0)
    
    def pushdown(self,node:Node):
        mid = (node.l+node.r)>>1
        if node.left is None:
            node.left = Node(None,None,node.l,mid,0,0)
        if node.right is None:
            node.right = Node(None,None,mid+1,node.r,0,0)
        if node.add:
            node.left.sum = node.add
            node.right.sum = node.add
            node.left.add = node.right.add = node.add
            node.add = 0
    
    def pushup(self,node:Node):
        node.sum = max(node.left.sum,node.right.sum)
    
    def update(self,l:int,r:int,diff:int,node:Node):
        if l<=node.l and node.r<=r:
            node.sum = diff
            node.add = diff
            return
        
        self.pushdown(node)
        mid = (node.l+node.r)>>1
        if l<=mid:
            self.update(l,r,diff,node.left)
        if r>mid:
            self.update(l,r,diff,node.right)
        self.pushup(node)

    def query(self,l:int,r:int,node:Node)->int:
        if l<=node.l and node.r<=r:
            return node.sum
        
        self.pushdown(node)
        
        res = 0
        mid = (node.l+node.r)>>1
        if l<=mid:
            res = max(res,self.query(l,r,node.left))
        if r>mid:
            res = max(res,self.query(l,r,node.right))
        return res

class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        mx = max(l+s for l,s in positions)+1
        st = DynamicSegTree(1,mx)

        ans = []
        tmp = 0
        for l,s in positions:
            r = l+s-1
            h = st.query(l,r,st.root)+s
            tmp = max(h,tmp)
            ans.append(tmp)
            st.update(l,r,h,st.root)
        return ans

  • 12
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值