Python 线段树:区间求和,实现y=ax+b的映射

目录

1.原理

2.完整代码

3.应用


1.原理

懒标记的初始化:

self.tag=["1#0"]*(n<<2+1)

这里的1和0对于的y=ax+b中的a和b,显然当a=1,b=0时,y=x

push_up()

def push_up(self,p):
        self.tree[p]=self.tree[p<<1]+self.tree[p<<1|1] 

线段树维护的是区间和,节点维护两个子节点的值的和

push_down()和update()原理

 最顶层的结点的值显然是

s = x_{1}+x_{2}+x_{3}+x_{4}=\sum_{1}^{4}x_{i}

如果将这个区间的每个值都用ax+b映射,那么最顶层的结点应该为

s=\sum _{1}^{4}a*x+b = a*\sum_{1}^{4}x + b * (4 - 1 + 1)

因此,每次更新tree的节点时,我们需要用到结点当前的值

self.tree[p]=a*self.tree[p]+b*(pr-pl+1)

这里的pr-pl+1就是区间长度

更新懒标记tag时,假设当前懒标记的值为"ai#bi", update的参数为a,b

y = a*(a_{i} * x + b_{i}) + b = (a * a_{i} )* x + (a * b_{i} + b)

ai,bi=map(int,self.tag[p<<1].split("#"))
self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)

因此:

update():

def update(self,p,pl,pr,L,R,a,b):  # y=a*x+b
    if L<=pl and pr<=R:
        self.tree[p]=a*self.tree[p]+b*(pr-pl+1)
        ai,bi=map(int,self.tag[p].split("#"))
        self.tag[p]="{}#{}".format(a*ai,a*bi+b)
        return
    mid=(pl+pr)>>1
    self.push_down(p,mid-pl+1,pr-mid)
    if L<=mid:
        self.update(p<<1,pl,mid,L,R,a,b)
    if R>mid:
        self.update(p<<1|1,mid+1,pr,L,R,a,b)
    self.push_up(p)

push_down():

def push_down(self,p,pl,pr):
    if self.tag[p]!="1#0": 
        a,b=map(int,self.tag[p].split("#"))

        ai,bi=map(int,self.tag[p<<1].split("#"))
        self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)

        ai,bi=map(int,self.tag[p<<1|1].split("#"))
        self.tag[p<<1|1]="{}#{}".format(a*ai,a*bi+b)

        self.tree[p<<1]=a*self.tree[p<<1]+b*pl
        self.tree[p<<1|1]=a*self.tree[p<<1|1]+b*pr
        self.tag[p]="1#0"

2.完整代码

class SegmentTree(object):
    def __init__(self,n,nums):
        self.a=[0]+nums
        self.tree=[0]*(n<<2+1)
        self.tag=["1#0"]*(n<<2+1)

    def push_up(self,p):
        self.tree[p]=self.tree[p<<1]+self.tree[p<<1|1] 

    def push_down(self,p,pl,pr):
        if self.tag[p]!="1#0": 
            a,b=map(int,self.tag[p].split("#"))

            ai,bi=map(int,self.tag[p<<1].split("#"))
            self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)

            ai,bi=map(int,self.tag[p<<1|1].split("#"))
            self.tag[p<<1|1]="{}#{}".format(a*ai,a*bi+b)

            self.tree[p<<1]=a*self.tree[p<<1]+b*pl
            self.tree[p<<1|1]=a*self.tree[p<<1|1]+b*pr
            self.tag[p]="1#0"

    def build(self,p,pl,pr):  
        if pl==pr:
            self.tree[p]=self.a[pl]
            return
        mid=(pl+pr)>>1
        self.build(p<<1,pl,mid)
        self.build(p<<1|1,mid+1,pr)
        self.push_up(p)

    def update(self,p,pl,pr,L,R,a,b):  # y=a*x+b
        if L<=pl and pr<=R:
            self.tree[p]=a*self.tree[p]+b*(pr-pl+1)
            ai,bi=map(int,self.tag[p].split("#"))
            self.tag[p]="{}#{}".format(a*ai,a*bi+b)
            return
        mid=(pl+pr)>>1
        self.push_down(p,mid-pl+1,pr-mid)
        if L<=mid:
            self.update(p<<1,pl,mid,L,R,a,b)
        if R>mid:
            self.update(p<<1|1,mid+1,pr,L,R,a,b)
        self.push_up(p)
        
    def query(self,p,pl,pr,L,R):
        if L<=pl and pr<=R:
            return self.tree[p]
        ans=0
        mid=(pl+pr)>>1
        self.push_down(p,mid-pl+1,pr-mid)
        if L<=mid:
            ans+=self.query(p<<1,pl,mid,L,R)
        if R>mid:
            ans+=self.query(p<<1|1,mid+1,pr,L,R)
        return ans

测试用例:

if __name__ == "__main__":
    n=8
    nums=[1,4,2,8,5,7,9,3]
    st=SegmentTree(n,nums)
    st.build(1,1,n)
    print(st.query(1,1,n,2,5))

    st.update(1,1,n,4,6,3,4) 
    print(st.query(1,1,n,3,5)) 
    
    st.update(1,1,n,2,4,2,1)
    print(st.query(1,1,n,2,7)) 

tree = 1,4,2,8,5,7,9,3 

第一次查询:区间25的和

s = \sum_{2}^{5}x_{i}=x_{2}+x_{3}+x_{4}+x_{5}=19

第一次更新:将区间46的值通过y=3x+4映射 

y_{4}=3*x_{4}+4=28

y_{5}=3*x_{5}+4=19

y_{6}=3*x_{6}+4=25

tree = 1,4,2,28,19,25,9,3

第二次查询:区间35的和

s = \sum_{3}^{5}x_{i}=x_{3}+x_{4}+x_{5} = 2+28+19=49

第二次更新:区间24的值通过y=2x+1映射

y_{2}=2*x_2+1=9

y_{3}=2*x_{3}+1=5

y_{4}=2*x_{4}+1=57

tree = 1,9,5,57,19,25,9,3

第三次查询:区间27的和

s=\sum_{2}^{7}x_{i}=x_{2}+x_{3}+x_{4}+x_{5}+x_{6}+x_{7}=124

3.应用

a=0,b=k时,可以实现区间上的元素替换为值k

a=1,b=k时,可以实现区间上的元素加上值k

a=k,b=0时,可以实现区间上的元素乘上值k

对应于洛谷的一道题目:P2023 [AHOI2009] 维护序列

代码:

class SegmentTree(object):
    def __init__(self,n,nums):
        self.a=[0]+nums
        self.tree=[0]*(n<<2+1)
        self.tag=["1#0"]*(n<<2+1)

    def push_up(self,p):
        self.tree[p]=self.tree[p<<1]+self.tree[p<<1|1]

    def push_down(self,p,pl,pr):
        if self.tag[p]!="1#0":
            a,b=map(int,self.tag[p].split("#"))

            ai,bi=map(int,self.tag[p<<1].split("#"))
            self.tag[p<<1]="{}#{}".format(a*ai,a*bi+b)

            ai,bi=map(int,self.tag[p<<1|1].split("#"))
            self.tag[p<<1|1]="{}#{}".format(a*ai,a*bi+b)

            self.tree[p<<1]=a*self.tree[p<<1]+b*pl
            self.tree[p<<1|1]=a*self.tree[p<<1|1]+b*pr
            self.tag[p]="1#0"

    def build(self,p,pl,pr):   # 建树
        if pl==pr:
            self.tree[p]=self.a[pl]
            return
        mid=(pl+pr)>>1
        self.build(p<<1,pl,mid)
        self.build(p<<1|1,mid+1,pr)
        self.push_up(p)

    def update(self,p,pl,pr,L,R,a,b):  # y=a*x+b
        if L<=pl and pr<=R:
            self.tree[p]=a*self.tree[p]+b*(pr-pl+1) 
            ai,bi=map(int,self.tag[p].split("#"))
            self.tag[p]="{}#{}".format(a*ai,a*bi+b)
            return
        mid=(pl+pr)>>1
        self.push_down(p,mid-pl+1,pr-mid)
        if L<=mid:
            self.update(p<<1,pl,mid,L,R,a,b)
        if R>mid:
            self.update(p<<1|1,mid+1,pr,L,R,a,b)
        self.push_up(p)
        
    def query(self,p,pl,pr,L,R):
        if L<=pl and pr<=R:
            return self.tree[p]
        ans=0
        mid=(pl+pr)>>1
        self.push_down(p,mid-pl+1,pr-mid)
        if L<=mid:
            ans+=self.query(p<<1,pl,mid,L,R)
        if R>mid:
            ans+=self.query(p<<1|1,mid+1,pr,L,R)
        return ans

n,k=map(int,input().split())
nums=[int(i) for i in input().split()]
st=SegmentTree(n,nums)
st.build(1,1,n)
m=int(input())
for _ in range(m):
    seq=input()
    if seq[0]=='1':
        _,t,g,c=map(int,seq.split())
        st.update(1,1,n,t,g,c,0)
    elif seq[0]=='2':
        _,t,g,c=map(int,seq.split())
        st.update(1,1,n,t,g,1,c)
    elif seq[0]=='3':
        _,t,g=map(int,seq.split())
        print(st.query(1,1,n,t,g)%k)

但是受限于python运算速度太慢,因此大部分都超时,不过没有超时的都得到正解:

如有错误,欢迎指正

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值