目录
1.原理
懒标记的初始化:
self.tag=["1#0"]*(n<<2+1)
这里的1和0对于的y=ax+b中的a和b,显然当a=1,b=0时,y=x
push_up()
def push_up(self,p):
self.tree[p]=self.tree[p<<1]+self.tree[p<<1|1]
线段树维护的是区间和,节点维护两个子节点的值的和
push_down()和update()原理
最顶层的结点的值显然是
如果将这个区间的每个值都用ax+b映射,那么最顶层的结点应该为
因此,每次更新tree的节点时,我们需要用到结点当前的值
self.tree[p]=a*self.tree[p]+b*(pr-pl+1)
这里的pr-pl+1就是区间长度
更新懒标记tag时,假设当前懒标记的值为"ai#bi", update的参数为a,b
ai,bi=map(int,self.tag[p<<1].split("#"))
self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)
因此:
update():
def update(self,p,pl,pr,L,R,a,b): # y=a*x+b
if L<=pl and pr<=R:
self.tree[p]=a*self.tree[p]+b*(pr-pl+1)
ai,bi=map(int,self.tag[p].split("#"))
self.tag[p]="{}#{}".format(a*ai,a*bi+b)
return
mid=(pl+pr)>>1
self.push_down(p,mid-pl+1,pr-mid)
if L<=mid:
self.update(p<<1,pl,mid,L,R,a,b)
if R>mid:
self.update(p<<1|1,mid+1,pr,L,R,a,b)
self.push_up(p)
push_down():
def push_down(self,p,pl,pr):
if self.tag[p]!="1#0":
a,b=map(int,self.tag[p].split("#"))
ai,bi=map(int,self.tag[p<<1].split("#"))
self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)
ai,bi=map(int,self.tag[p<<1|1].split("#"))
self.tag[p<<1|1]="{}#{}".format(a*ai,a*bi+b)
self.tree[p<<1]=a*self.tree[p<<1]+b*pl
self.tree[p<<1|1]=a*self.tree[p<<1|1]+b*pr
self.tag[p]="1#0"
2.完整代码
class SegmentTree(object):
def __init__(self,n,nums):
self.a=[0]+nums
self.tree=[0]*(n<<2+1)
self.tag=["1#0"]*(n<<2+1)
def push_up(self,p):
self.tree[p]=self.tree[p<<1]+self.tree[p<<1|1]
def push_down(self,p,pl,pr):
if self.tag[p]!="1#0":
a,b=map(int,self.tag[p].split("#"))
ai,bi=map(int,self.tag[p<<1].split("#"))
self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)
ai,bi=map(int,self.tag[p<<1|1].split("#"))
self.tag[p<<1|1]="{}#{}".format(a*ai,a*bi+b)
self.tree[p<<1]=a*self.tree[p<<1]+b*pl
self.tree[p<<1|1]=a*self.tree[p<<1|1]+b*pr
self.tag[p]="1#0"
def build(self,p,pl,pr):
if pl==pr:
self.tree[p]=self.a[pl]
return
mid=(pl+pr)>>1
self.build(p<<1,pl,mid)
self.build(p<<1|1,mid+1,pr)
self.push_up(p)
def update(self,p,pl,pr,L,R,a,b): # y=a*x+b
if L<=pl and pr<=R:
self.tree[p]=a*self.tree[p]+b*(pr-pl+1)
ai,bi=map(int,self.tag[p].split("#"))
self.tag[p]="{}#{}".format(a*ai,a*bi+b)
return
mid=(pl+pr)>>1
self.push_down(p,mid-pl+1,pr-mid)
if L<=mid:
self.update(p<<1,pl,mid,L,R,a,b)
if R>mid:
self.update(p<<1|1,mid+1,pr,L,R,a,b)
self.push_up(p)
def query(self,p,pl,pr,L,R):
if L<=pl and pr<=R:
return self.tree[p]
ans=0
mid=(pl+pr)>>1
self.push_down(p,mid-pl+1,pr-mid)
if L<=mid:
ans+=self.query(p<<1,pl,mid,L,R)
if R>mid:
ans+=self.query(p<<1|1,mid+1,pr,L,R)
return ans
测试用例:
if __name__ == "__main__":
n=8
nums=[1,4,2,8,5,7,9,3]
st=SegmentTree(n,nums)
st.build(1,1,n)
print(st.query(1,1,n,2,5))
st.update(1,1,n,4,6,3,4)
print(st.query(1,1,n,3,5))
st.update(1,1,n,2,4,2,1)
print(st.query(1,1,n,2,7))
第一次查询:区间2到5的和
第一次更新:将区间4到6的值通过y=3x+4映射
第二次查询:区间3到5的和
第二次更新:区间2到4的值通过y=2x+1映射
第三次查询:区间2到7的和
3.应用
当a=0,b=k时,可以实现区间上的元素替换为值k
当a=1,b=k时,可以实现区间上的元素加上值k
当a=k,b=0时,可以实现区间上的元素乘上值k
对应于洛谷的一道题目:P2023 [AHOI2009] 维护序列
代码:
class SegmentTree(object):
def __init__(self,n,nums):
self.a=[0]+nums
self.tree=[0]*(n<<2+1)
self.tag=["1#0"]*(n<<2+1)
def push_up(self,p):
self.tree[p]=self.tree[p<<1]+self.tree[p<<1|1]
def push_down(self,p,pl,pr):
if self.tag[p]!="1#0":
a,b=map(int,self.tag[p].split("#"))
ai,bi=map(int,self.tag[p<<1].split("#"))
self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)
ai,bi=map(int,self.tag[p<<1|1].split("#"))
self.tag[p<<1|1]="{}#{}".format(a*ai,a*bi+b)
self.tree[p<<1]=a*self.tree[p<<1]+b*pl
self.tree[p<<1|1]=a*self.tree[p<<1|1]+b*pr
self.tag[p]="1#0"
def build(self,p,pl,pr): # 建树
if pl==pr:
self.tree[p]=self.a[pl]
return
mid=(pl+pr)>>1
self.build(p<<1,pl,mid)
self.build(p<<1|1,mid+1,pr)
self.push_up(p)
def update(self,p,pl,pr,L,R,a,b): # y=a*x+b
if L<=pl and pr<=R:
self.tree[p]=a*self.tree[p]+b*(pr-pl+1)
ai,bi=map(int,self.tag[p].split("#"))
self.tag[p]="{}#{}".format(a*ai,a*bi+b)
return
mid=(pl+pr)>>1
self.push_down(p,mid-pl+1,pr-mid)
if L<=mid:
self.update(p<<1,pl,mid,L,R,a,b)
if R>mid:
self.update(p<<1|1,mid+1,pr,L,R,a,b)
self.push_up(p)
def query(self,p,pl,pr,L,R):
if L<=pl and pr<=R:
return self.tree[p]
ans=0
mid=(pl+pr)>>1
self.push_down(p,mid-pl+1,pr-mid)
if L<=mid:
ans+=self.query(p<<1,pl,mid,L,R)
if R>mid:
ans+=self.query(p<<1|1,mid+1,pr,L,R)
return ans
n,k=map(int,input().split())
nums=[int(i) for i in input().split()]
st=SegmentTree(n,nums)
st.build(1,1,n)
m=int(input())
for _ in range(m):
seq=input()
if seq[0]=='1':
_,t,g,c=map(int,seq.split())
st.update(1,1,n,t,g,c,0)
elif seq[0]=='2':
_,t,g,c=map(int,seq.split())
st.update(1,1,n,t,g,1,c)
elif seq[0]=='3':
_,t,g=map(int,seq.split())
print(st.query(1,1,n,t,g)%k)
但是受限于python运算速度太慢,因此大部分都超时,不过没有超时的都得到正解:
如有错误,欢迎指正