0、基本模板
0.1单点修改
class Node:
    """Segment-tree node for point updates: covers [l, r] and stores the sum ``s``."""

    def __init__(self, l, r, s):
        # Interval bounds followed by the aggregated sum of that interval.
        self.l, self.r, self.s = l, r, s
def pushup(u, l, r):
    """Recompute parent ``u``'s sum from its children ``l`` and ``r``."""
    left_sum, right_sum = l.s, r.s
    u.s = left_sum + right_sum
def build(u, l, r):
    """Build node ``u`` of the global tree ``tr`` over [l, r] from the global array ``a``.

    Leaves take ``a[l]`` as their sum; internal nodes are split at the midpoint
    and their sum is merged from the two children with ``pushup``.
    """
    if l == r:
        tr[u] = Node(l, r, a[l])
        return
    mid = (l + r) >> 1
    tr[u] = Node(l, r, 0)
    build(2 * u, l, mid)
    build(2 * u + 1, mid + 1, r)
    pushup(tr[u], tr[2 * u], tr[2 * u + 1])
def modify(u, x, v):
    """Point assignment: set position ``x`` to ``v`` and refresh the sums above it.

    Args:
        u: index of the current node in the global ``tr`` (call with 1 for the root).
        x: position to update; must lie inside ``[tr[u].l, tr[u].r]``.
        v: new value stored at position ``x``.
    """
    node = tr[u]
    if node.l == node.r == x:
        # Leaf reached: store the value in ``s``, the field pushup/query read
        # (the previous ``.v`` attribute was never read by anything).
        node.s = v
        return
    mid = (node.l + node.r) >> 1
    # Recurse with the point and value being written, not a range.
    if x <= mid:
        modify(u << 1, x, v)
    else:
        modify(u << 1 | 1, x, v)
    pushup(node, tr[u << 1], tr[u << 1 | 1])
def query(u, l, r):
    """Return a node whose ``s`` is the sum over [l, r] within node ``u``'s interval.

    If node ``u`` lies entirely inside [l, r], the stored tree node itself is
    returned (callers should treat it as read-only). Otherwise the answer comes
    from the overlapping child, or from a fresh ``Node`` that merges both children.
    """
    node = tr[u]
    if l <= node.l <= node.r <= r:
        return node
    mid = (node.l + node.r) >> 1
    if r <= mid:
        # The query range lies entirely in the left half.
        return query(u << 1, l, r)
    if l > mid:
        # The query range lies entirely in the right half.
        return query(u << 1 | 1, l, r)
    merged = Node(0, 0, 0)
    pushup(merged, query(u << 1, l, r), query(u << 1 | 1, l, r))
    return merged
# Heap-indexed storage with the root at index 1; 4 * N slots is the usual size bound.
# NOTE(review): N, n and the 1-indexed array a must be defined by the surrounding
# solution before this point.
tr = [None] * (4 * N)
build(1, 1, n)
- 维护的属性可以是:
sum、min、max、gcd、lcm、区间积、区间异或和、区间最大值个数、区间最小值个数等
class Node:
    """Segment-tree node carrying several aggregates of the elements in [l, r].

    Fields: interval bounds ``l``/``r``, sum ``s``, tag ``add``, extrema
    ``min_val``/``max_val`` with their occurrence counts ``min_count``/``max_count``,
    plus ``gcd_val``, ``lcm_val``, ``product_val`` and ``xor_val``.
    """

    def __init__(self, l, r, s, add, min_val, max_val, gcd_val, lcm_val,
                 product_val, xor_val, max_count, min_count):
        self.l, self.r = l, r
        self.s = s
        self.add = add
        self.min_val, self.max_val = min_val, max_val
        self.gcd_val, self.lcm_val = gcd_val, lcm_val
        self.product_val = product_val
        self.xor_val = xor_val
        self.max_count, self.min_count = max_count, min_count
def pushup(u, l, r):
    """Merge every aggregate of children ``l`` and ``r`` into parent ``u``.

    The extremum counts add up when both sides share the same extremum;
    otherwise the parent takes the count from the side holding the extremum.
    """
    u.s = l.s + r.s
    u.min_val = min(l.min_val, r.min_val)
    u.max_val = max(l.max_val, r.max_val)
    u.gcd_val = gcd(l.gcd_val, r.gcd_val)
    u.lcm_val = lcm(l.lcm_val, r.lcm_val)
    u.product_val = l.product_val * r.product_val
    u.xor_val = l.xor_val ^ r.xor_val
    if l.max_val == r.max_val:
        u.max_count = l.max_count + r.max_count
    elif l.max_val > r.max_val:
        u.max_count = l.max_count
    else:
        u.max_count = r.max_count
    if l.min_val == r.min_val:
        u.min_count = l.min_count + r.min_count
    elif l.min_val < r.min_val:
        u.min_count = l.min_count
    else:
        u.min_count = r.min_count
0.2区间修改
class Node:
    """Range-update segment-tree node: interval [l, r], sum ``s`` and pending ``add`` tag."""

    def __init__(self, l, r, s, add):
        self.l, self.r = l, r
        self.s, self.add = s, add
def pushup(u, l, r):
    """Set parent ``u``'s sum to the total of its children ``l`` and ``r``."""
    combined = l.s + r.s
    u.s = combined
def pushdown(u, l, r):
    """Push parent ``u``'s pending add tag into children ``l`` and ``r``, then clear it.

    Each child's sum grows by the tag times the child's length, and the tag is
    accumulated on the child for later propagation. Only ``s`` and ``add`` are touched.
    """
    tag = u.add
    if not tag:
        return
    for child in (l, r):
        child.add += tag
        child.s += tag * (child.r - child.l + 1)
    u.add = 0
def build(u, l, r):
    """Build node ``u`` over [l, r] from the global array ``a``, with every add tag at 0."""
    if l != r:
        mid = (l + r) >> 1
        tr[u] = Node(l, r, 0, 0)
        left_child, right_child = u << 1, u << 1 | 1
        build(left_child, l, mid)
        build(right_child, mid + 1, r)
        pushup(tr[u], tr[left_child], tr[right_child])
    else:
        tr[u] = Node(l, r, a[l], 0)
def modify(u, l, r, x):
    """Add ``x`` to every element of [l, r], tagging fully covered nodes lazily.

    A fully covered node gets its sum updated immediately and the rest recorded
    in its ``add`` tag. Otherwise the pending tag is pushed down, the overlapping
    children are updated, and the sum is recomputed.
    """
    node = tr[u]
    if l <= node.l <= node.r <= r:
        node.s += (node.r - node.l + 1) * x
        node.add += x
        return
    pushdown(node, tr[u << 1], tr[u << 1 | 1])
    mid = (node.l + node.r) >> 1
    if l <= mid:
        modify(u << 1, l, r, x)
    if r > mid:
        modify(u << 1 | 1, l, r, x)
    pushup(node, tr[u << 1], tr[u << 1 | 1])
def query(u, l, r):
    """Return a node whose ``s`` is the sum over [l, r], resolving pending tags on the way.

    A fully covered node is returned as-is (treat it as read-only). Otherwise
    its tag is pushed down first, so both children hold up-to-date sums.
    """
    cur = tr[u]
    if l <= cur.l <= cur.r <= r:
        return cur
    pushdown(cur, tr[u << 1], tr[u << 1 | 1])
    mid = (cur.l + cur.r) >> 1
    if r <= mid:
        return query(u << 1, l, r)
    if mid < l:
        return query(u << 1 | 1, l, r)
    res = Node(0, 0, 0, 0)
    lhs = query(u << 1, l, r)
    rhs = query(u << 1 | 1, l, r)
    pushup(res, lhs, rhs)
    return res
# Heap-indexed storage with the root at index 1; 4 * N slots is the usual size bound.
# NOTE(review): N, n and the 1-indexed array a must be defined by the surrounding
# solution before this point.
tr = [None] * (4 * N)
build(1, 1, n)
class Node:
    """Multi-aggregate segment-tree node for range updates.

    It stores the bounds ``l``/``r``, sum ``s`` and pending ``add`` tag. It also
    stores the extrema and their counts (``min_val``, ``max_val``, ``min_count``,
    ``max_count``), plus ``gcd_val``, ``lcm_val``, ``product_val`` and ``xor_val``.
    """

    def __init__(self, l, r, s, add, min_val, max_val, gcd_val, lcm_val,
                 product_val, xor_val, max_count, min_count):
        (self.l, self.r, self.s, self.add,
         self.min_val, self.max_val, self.gcd_val, self.lcm_val,
         self.product_val, self.xor_val, self.max_count, self.min_count) = (
            l, r, s, add, min_val, max_val, gcd_val, lcm_val,
            product_val, xor_val, max_count, min_count)
def pushup(u, l, r):
    """Recompute parent ``u``'s aggregates from its two children ``l`` and ``r``."""
    u.s = l.s + r.s
    u.min_val = min(l.min_val, r.min_val)
    u.max_val = max(l.max_val, r.max_val)
    u.gcd_val = gcd(l.gcd_val, r.gcd_val)
    u.lcm_val = lcm(l.lcm_val, r.lcm_val)
    u.product_val = l.product_val * r.product_val
    u.xor_val = l.xor_val ^ r.xor_val
    # How many times the maximum occurs: a tie sums both sides' counts.
    lmax, rmax = l.max_val, r.max_val
    if lmax == rmax:
        u.max_count = l.max_count + r.max_count
    elif lmax > rmax:
        u.max_count = l.max_count
    else:
        u.max_count = r.max_count
    # Same rule for the minimum.
    lmin, rmin = l.min_val, r.min_val
    if lmin == rmin:
        u.min_count = l.min_count + r.min_count
    elif lmin < rmin:
        u.min_count = l.min_count
    else:
        u.min_count = r.min_count
def _apply_add(node, x):
    """Add ``x`` to every element under ``node`` and record it in the node's lazy tag.

    Only aggregates that shift predictably under a uniform add are updated:
    ``s`` grows by ``x`` per element, and ``min_val``/``max_val`` grow by ``x``.
    A uniform shift keeps the same elements extremal, so ``max_count`` and
    ``min_count`` are unchanged. ``gcd_val``, ``lcm_val``, ``product_val`` and
    ``xor_val`` of a shifted range cannot be derived from their old values. They
    are therefore recomputed only when ``node`` is a single element (l == r).
    """
    node.s += (node.r - node.l + 1) * x
    node.min_val += x
    node.max_val += x
    if node.l == node.r:
        # A single element: every value-based aggregate equals the new value
        # (gcd/lcm of one number are taken as non-negative, like math.gcd/lcm).
        value = node.s
        node.gcd_val = abs(value)
        node.lcm_val = abs(value)
        node.product_val = value
        node.xor_val = value
    node.add += x


def modify(u, l, r, x):
    """Range add: add ``x`` to every element in [l, r] inside node ``u``.

    Fully covered nodes are updated with ``_apply_add`` and tagged lazily.
    Pending tags are also pushed into the children with ``_apply_add``. The
    sum-only ``pushdown`` would leave the children's ``min_val``/``max_val`` stale.

    Limitation: once a range add lazily tags an internal node, that node's
    ``gcd_val``/``lcm_val``/``product_val``/``xor_val`` (and its ancestors') are
    no longer exact. These fields stay exact only while every update has l == r,
    since such updates always descend to a leaf. Queries on this node type also
    need a pushdown that shifts ``min_val``/``max_val``.

    Args:
        u: index of the current node in the global ``tr`` (call with 1 for the root).
        l: left end of the update range (inclusive).
        r: right end of the update range (inclusive).
        x: value added to each element.
    """
    node = tr[u]
    if l <= node.l <= node.r <= r:
        _apply_add(node, x)
        return
    left, right = tr[u << 1], tr[u << 1 | 1]
    if node.add:
        # Resolve the pending tag on both children before recursing.
        _apply_add(left, node.add)
        _apply_add(right, node.add)
        node.add = 0
    mid = (node.l + node.r) >> 1
    if mid >= l:
        modify(u << 1, l, r, x)
    if mid < r:
        modify(u << 1 | 1, l, r, x)
    pushup(node, left, right)