模板
懒标记/朴素线段树
refs:https://www.bilibili.com/video/BV1G34y1L7b3/?spm_id_from=333.999.0.0(董晓算法)
模板题:【模板】线段树 1 - 洛谷 and 【模板】树状数组 1 - 洛谷
LuoGu语言歧视还卡常,无语,两道题其实都是写对的,但是他硬卡到了70/100。T我三个点。
from typing import List
# 带懒标记
class node:
def __init__(self,l:int,r:int,sum:int,add:int) -> None:
self.l = l
self.r = r
self.sum = sum
self.add = add
# 下标从1开始
class SegTree:
def lc(self,p:int)->int:
return p<<1
def rc(self,p:int)->int:
return (p<<1)|1
# 懒标记向下传递
def pushdown(self,curr:int):
if self.tr[curr].add:
self.tr[self.lc(curr)].sum += (self.tr[self.lc(curr)].r-self.tr[self.lc(curr)].l+1)*self.tr[curr].add
self.tr[self.rc(curr)].sum += (self.tr[self.rc(curr)].r-self.tr[self.rc(curr)].l+1)*self.tr[curr].add
self.tr[self.lc(curr)].add += self.tr[curr].add
self.tr[self.rc(curr)].add += self.tr[curr].add
self.tr[curr].add= 0
# 懒标记回溯
def pushup(self,curr:int):
self.tr[curr].sum = self.tr[self.lc(curr)].sum + self.tr[self.rc(curr)].sum
def __init__(self,w:List[int]) -> None:
w = [0]+w
n = len(w)
self.tr = [node(-1,-1,-1,-1) for _ in range(4*n)]
def build(curr:int,l:int,r:int):
self.tr[curr] = node(l,r,w[l],0)
if l==r:
return
mid = (l+r)>>1
build(self.lc(curr),l,mid)
build(self.rc(curr),mid+1,r)
self.tr[curr].sum = self.tr[self.lc(curr)].sum + self.tr[self.rc(curr)].sum
build(1,1,n-1)
def update_pt(self,curr:int,x:int,diff:int):
if self.tr[curr].l == self.tr[curr].r == x:
self.tr[curr].sum += diff
return
mid = (self.tr[curr].l+self.tr[curr].r)>>1
if x<=mid:
self.update_pt(self.lc(curr),x,diff)
else:
self.update_pt(self.rc(curr),x,diff)
self.pushup(curr)
def update(self,curr:int,x:int,y:int,diff:int):
if x<=self.tr[curr].l and self.tr[curr].r<=y:
self.tr[curr].sum += (self.tr[curr].r-self.tr[curr].l+1)*diff
self.tr[curr].add += diff
return
mid = (self.tr[curr].l + self.tr[curr].r)>>1
self.pushdown(curr)
if x<=mid:
self.update(self.lc(curr),x,y,diff)
if y>mid:
self.update(self.rc(curr),x,y,diff)
self.pushup(curr)
def query(self,curr:int,x:int,y:int)->int:
if x<=self.tr[curr].l and self.tr[curr].r<=y:
return self.tr[curr].sum
mid = (self.tr[curr].l+self.tr[curr].r)>>1
self.pushdown(curr)
res = 0
if x<=mid:
res += self.query(self.lc(curr),x,y)
if y>mid:
res += self.query(self.rc(curr),x,y)
return res
下标
从1开始,接收一个从0开始的数组。然后在前面拼接一个dummy元素即可将下标从1开始。build时记得n-1。
区间/点修改
- 点修改:用不上懒标记
- 区间修改:用懒标记。但不用担心和点修改导致冲突。点修改本身不会产生任何非叶子的add增量,所以这两个修改方式和懒标记都是兼容的。
这里简单说下懒标记为什么回溯。主要我一开始以为:
def update(self,curr:int,x:int,y:int,diff:int):
if x<=self.tr[curr].l and self.tr[curr].r<=y:
self.tr[curr].sum += (self.tr[curr].r-self.tr[curr].l+1)*diff
self.tr[curr].add += diff
return
这一段已经改过父节点了,为什么下面还pushup,这不多此一举吗?后来我发现这个覆盖不是对任何情况生效的。举个例子:
在上图中我们修改区间[4,9],显然一开始会从线段[1,10]分裂。而分裂时不走if。因此[1,10]这个点的sum如果不pushup就不会更新。所以还是要回溯的。
区间覆盖参数
query和update中都有待查询/修改区间的左右边界x,y。可能会疑惑为什么递归时传参一直写死x,y,而不是[x,mid]和[mid+1,y]。(比如查询时)
这个写法在左半部分是没问题的,即[x,mid],因为这个区间是肯定要查的。问题出在右半部分,举个例子,假设我们要查[5,9]:
- mid[1,10]=5,分裂为[5,5]和[6,9]
- mid[1,5]=3,这个时候[mid+1,y] = [4,5],给下一层传参的x=4,y=5,这样if检查发现[4,5]正好和[x,y]重叠,那么就会多查一个[4,4]的区间。
换句话说,如果我们的x过大,甚至比mid+1还大,这样写查询范围就会多查了。
那么接下来就可能会想了:如果写成
query(self.rc(curr),min(y,max(mid+1,x)),y)
行不行?
- 首先,max(mid+1,x)保证了左边界不会爆掉。
- 其次,min(max(mid+1,x),y)保证了左边界不会把右边界爆掉。
其实还是不行。从LuoGu上下了数据,发现这一组会WA:
8 10
640 591 141 307 942 58 775 133
2 1 5
2 3 8
2 3 6
2 5 8
2 4 8
1 4 8 60
2 1 6
2 5 8
1 3 7 15
1 2 6 86
其中第一个操作(2,1,5)就错了,答案是2621,我的程序跑出来2679。正好多了第六个元素58。
- 因为根节点左子树跑满了,所以左子树肯定不用看了,对的。
- [1,8] mid4,y=5,右子树要跑,x=5,y=5。
- [5,8] mid6,y=5,左子树要跑,x=5,y=6
这下就多跑了个w[6]了。
因此我们只要改成:
if x<=mid:
res += self.query(self.lc(curr),x,max(x,min(y,mid)))
if y>mid:
res += self.query(self.rc(curr),min(y,max(mid+1,x)),y)
即可。
但这么写也太麻烦了,还不如一直维持区间覆盖的参数不变呢。也就是(x,y)。
动态开点线段树
有时候整个区间特别长,比如长度可能是1e9、1e10这种。然后并不是每个单位区间都有值(或者说被赋值/修改),这个时候再开静态的4N的数组可能会MLE。所以就有了动态开点线段树。
在OI-Wiki中,对动态开点线段树核心思想的描述是:
总之,动态开点线段树的核心思想就是:结点只有在有需要的时候才被创建。
即pure demanding。
class Node:
def __init__(self,left,right,l:int,r:int,sum:int,add:int) -> None:
self.left = left
self.right = right
self.l = l
self.r = r
self.sum = sum
self.add = add
def __str__(self)->str:
return f'range:[{self.l},{self.r}],sum:{self.sum}'
class DynamicSegTree:
def __init__(self,lx:int,ry:int) -> None:
self.root = Node(None,None,lx,ry,0,0)
def pushdown(self,node:Node):
mid = (node.l+node.r)>>1
if node.left is None:
node.left = Node(None,None,node.l,mid,0,0)
if node.right is None:
node.right = Node(None,None,mid+1,node.r,0,0)
if node.add:
node.left.sum += (node.left.r-node.left.l+1)*node.add
node.right.sum += (node.right.r-node.right.l+1)*node.add
node.left.add += node.add
node.right.add += node.add
node.add = 0
def pushup(self,node:Node):
node.sum = node.left.sum + node.right.sum
def update(self,l:int,r:int,diff:int,node:Node):
if l<=node.l and node.r<=r:
node.sum += (node.r-node.l+1)*diff
node.add += diff
return
self.pushdown(node)
mid = (node.l+node.r)>>1
if l<=mid:
self.update(l,r,diff,node.left)
if r>mid:
self.update(l,r,diff,node.right)
self.pushup(node)
def query(self,l:int,r:int,node:Node)->int:
if l<=node.l and node.r<=r:
return node.sum
self.pushdown(node)
res = 0
mid = (node.l+node.r)>>1
if l<=mid:
res += self.query(l,r,node.left)
if r>mid:
res += self.query(l,r,node.right)
return res
新增左右子树的时机
在pushdown中,有这么一段:
def pushdown(self,node:Node):
mid = (node.l+node.r)>>1
if node.left is None:
node.left = Node(None,None,node.l,mid,0,0)
if node.right is None:
node.right = Node(None,None,mid+1,node.r,0,0)
if node.add:
# ......
这里的问题是,为什么不把更新左右子树的时机再拖一下,放到if node.add里面去?
这是因为下面紧接着就有:
mid = (node.l+node.r)>>1
if l<=mid:
self.update(l,r,diff,node.left)
if r>mid:
self.update(l,r,diff,node.right)
如果之前没有懒标记就不新增左右子树,这里就会访问到None。举个例子,假设当前节点[1,5],更新[2,7]:
- mid = 3,分裂为[2,3],[4,5]。且此时当前节点add=0
- 1≤3,进左子树更新,由于pushdown没更新,那么node.left=None。完蛋。
所以只要会分裂,无论进哪个子树,直接新增就行了。
1. LC 699 掉落的方块
这题1e8+1e6的数据范围,开4N就MLE了。所以板子用动态开点的线段树。
比较重要的是这个线段树板子怎么改。它的意思不是给区间里每个数上加上一个数,而是把区间里所有数改成一个数。所以这里应该是重新赋值,而不是累加。
def update(self,l:int,r:int,diff:int,node:Node):
if l<=node.l and node.r<=r:
node.sum = diff
node.add = diff
return
def pushdown(self,node:Node):
mid = (node.l+node.r)>>1
if node.left is None:
node.left = Node(None,None,node.l,mid,0,0)
if node.right is None:
node.right = Node(None,None,mid+1,node.r,0,0)
if node.add:
node.left.sum = node.add
node.right.sum = node.add
node.left.add = node.right.add = node.add
node.add = 0
对于回溯(pushup)我们可以发现,当一个新的方块落下时,他应该查找[l,r]中最大的高度并且累加上去。所以回溯时,我们取左右子树的较大值作为父节点的sum。
def pushup(self,node:Node):
node.sum = max(node.left.sum,node.right.sum)
以上图为例,假设方块是:
[2,1]
[2,9]
[1,8]
这里另外要提的一点是,因为贴边不能叠,所以我们直接对右边界减一。这样就不会重复累加贴边的情况了。例如第一个方块[x,y],那么[x:x+y-1],第二个方块[z,x-z],则[z:x-1],但实际上这俩是贴边的。
所以:
- 首先我们先修改[2,2]的值为1
- 然后查询[2,10]区间的最大值,发现是[2,2]的1,给他加上9变成10,更新[2,10]区间的所有元素为10。
这个时候[1,2]的最大值很显然为10(例如我们要在[1,2]上放一个方块,他就得从10开始算了),但[1,1]的值为0,[2,2]的值为10,因此取左右子树最大值,也即10。
具体实现就是套板子加上l+s-1的贴边处理:
from typing import List
class Node:
def __init__(self,left,right,l:int,r:int,sum:int,add:int) -> None:
self.left = left
self.right = right
self.l = l
self.r = r
self.sum = sum
self.add = add
class DynamicSegTree:
def __init__(self,lx:int,ry:int) -> None:
self.root = Node(None,None,lx,ry,0,0)
def pushdown(self,node:Node):
mid = (node.l+node.r)>>1
if node.left is None:
node.left = Node(None,None,node.l,mid,0,0)
if node.right is None:
node.right = Node(None,None,mid+1,node.r,0,0)
if node.add:
node.left.sum = node.add
node.right.sum = node.add
node.left.add = node.right.add = node.add
node.add = 0
def pushup(self,node:Node):
node.sum = max(node.left.sum,node.right.sum)
def update(self,l:int,r:int,diff:int,node:Node):
if l<=node.l and node.r<=r:
node.sum = diff
node.add = diff
return
self.pushdown(node)
mid = (node.l+node.r)>>1
if l<=mid:
self.update(l,r,diff,node.left)
if r>mid:
self.update(l,r,diff,node.right)
self.pushup(node)
def query(self,l:int,r:int,node:Node)->int:
if l<=node.l and node.r<=r:
return node.sum
self.pushdown(node)
res = 0
mid = (node.l+node.r)>>1
if l<=mid:
res = max(res,self.query(l,r,node.left))
if r>mid:
res = max(res,self.query(l,r,node.right))
return res
class Solution:
def fallingSquares(self, positions: List[List[int]]) -> List[int]:
mx = max(l+s for l,s in positions)+1
st = DynamicSegTree(1,mx)
ans = []
tmp = 0
for l,s in positions:
r = l+s-1
h = st.query(l,r,st.root)+s
tmp = max(h,tmp)
ans.append(tmp)
st.update(l,r,h,st.root)
return ans