树链剖分讲解 很清楚的一篇博客
下面是链接
https://www.cnblogs.com/chinhhh/p/7965433.html
模板
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define int ll
const int maxn=4e5+10;
int n,m,r,mod,cnt,tot=0;
int w[maxn],wt[maxn],top[maxn],head[maxn];
int ls(int x)
{
return x<<1;
}
int rs(int x)
{
return x<<1|1;
}
struct xy{
int next;
int to;
int dis;
}e[maxn];
struct node{
int val,laz;
}a[maxn<<2];
int son[maxn],id[maxn],fa[maxn],dep[maxn],siz[maxn];
inline void add(int x,int y,int d)
{
e[++cnt].dis=d;
e[cnt].to=y;
e[cnt].next=head[x];
head[x]=cnt;
}
inline void pushdown(int x,int l,int r)
{
int mid=(l+r)/2;
a[ls(x)].laz+=a[x].laz;
a[rs(x)].laz+=a[x].laz;
a[ls(x)].val+=a[x].laz*(mid-l+1);
a[rs(x)].val+=a[x].laz*(r-mid);
a[ls(x)].val%=mod;
a[rs(x)].val%=mod;
a[x].laz=0;
}
void pushup(int x)
{
a[x].val=(a[ls(x)].val+a[rs(x)].val)%mod;
}
void build(int x,int l,int r)
{
if(l==r)
{
a[x].val=wt[l]%mod;
return ;
}
int mid=(l+r)/2;
build(ls(x),l,mid);
build(rs(x),mid+1,r);
pushup(x);
}
int query(int l,int r,int nl,int nr,int x)
{
if(nl>=l&&nr<=r)
{
return a[x].val%mod;
}
pushdown(x,nl,nr);
int mid=(nl+nr)/2;
ll res=0;
if(l<=mid)
{
res+=query(l,r,nl,mid,ls(x));
}
if(r>mid)
{
res+=query(l,r,mid+1,nr,rs(x));
}
pushup(x);
return res%mod;
}
void updata(int l,int r,int nl,int nr,int x,int k)
{
if(nl>=l&&nr<=r)
{
a[x].laz+=k;
a[x].val+=k*(nr-nl+1);
a[x].val%=mod;
return ;
}
pushdown(x,nl,nr);
int mid=(nl+nr)/2;
int res=0;
if(l<=mid)
{
updata(l,r,nl,mid,ls(x),k);
}
if(r>mid)
{
updata(l,r,mid+1,nr,rs(x),k);
}
pushup(x);
}
int qrange(int x,int y)
{
ll ans=0;
while(top[x]!=top[y]){//当两个点不在同一条链上
if(dep[top[x]]<dep[top[y]])swap(x,y);//把x点改为所在链顶端的深度更深的那个点
ll res=0;
res=query(id[top[x]],id[x],1,n,1);//ans加上x点到x所在链顶端 这一段区间的点权和
ans+=res;
ans%=mod;//按题意取模
x=fa[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//直到两个点处于一条链上
if(dep[x]>dep[y])swap(x,y);//把x点深度更深的那个点
ans+=query(id[x],id[y],1,n,1);//这时再加上此时两个点的区间和即可
//cout<<ans<<endl;
return ans%mod;
}
void updrange(int x,int y,int k)
{
k%=mod;
while(top[x]!=top[y]){//当两个点不在同一条链上
if(dep[top[x]]<dep[top[y]])swap(x,y);//把x点改为所在链顶端的深度更深的那个点
updata(id[top[x]],id[x],1,n,1,k);//ans加上x点到x所在链顶端 这一段区间的点权和
x=fa[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//直到两个点处于一条链上
if(dep[x]>dep[y])swap(x,y);//把x点深度更深的那个点
updata(id[x],id[y],1,n,1,k);//这时再加上此时两个点的区间和即可
return ;
}
int qson(int x)
{
//cout<<id[x]<<" "<<id[x]+siz[x]-1<<endl;
return query(id[x],id[x]+siz[x]-1,1,n,1);
}
void updson(int x,int k)
{
updata(id[x],id[x]+siz[x]-1,1,n,1,k);
}
void dfs1(int x,int f,int deep)
{
dep[x]=deep;
fa[x]=f;
siz[x]=1;
int maxson=-1;
for(int i=head[x];i;i=e[i].next)
{
int y=e[i].to;
if(y==f)continue;
dfs1(y,x,deep+1);
siz[x]+=siz[y];
if(siz[y]>maxson)
{
son[x]=y;
maxson=siz[y];
}
}
}
void dfs2(int x,int topf)
{
id[x]=++tot;
wt[tot]=w[x];
top[x]=topf;
if(!son[x])return ;
dfs2(son[x],topf);
for(int i=head[x];i;i=e[i].next)
{
int y=e[i].to;
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}
signed main()
{
cin>>n>>m>>r>>mod;
for(int i=1;i<=n;i++)cin>>w[i];
int a,b;
for(int i=1;i<n;i++)
{
cin>>a>>b;
add(a,b,1);
add(b,a,1);
}
dfs1(r,0,1);
dfs2(r,r);
build(1,1,n);
//cout<<tot<<endl;
while(m--)
{
int k,x,y,z;
cin>>k;
//cout<<m<<endl;
if(k==1)
{
cin>>x>>y>>z;
updrange(x,y,z);
}
else if(k==2)
{
cin>>x>>y;
cout<<qrange(x,y)<<endl;
}
else if(k==3)
{
cin>>x>>y;
updson(x,y);
}
else
{
cin>>x;
cout<<qson(x)<<endl;
}
}
return 0;
}