# 【模板】重链剖分/树链剖分
## 题目描述
如题,已知一棵包含 $N$ 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
- `1 x y z`,表示将树从 $x$ 到 $y$ 结点最短路径上所有节点的值都加上 $z$。
- `2 x y`,表示求树从 $x$ 到 $y$ 结点最短路径上所有节点的值之和。
- `3 x z`,表示将以 $x$ 为根节点的子树内所有节点值都加上 $z$。
- `4 x` 表示求以 $x$ 为根节点的子树内所有节点值之和
## 输入格式
第一行包含 $4$ 个正整数 $N,M,R,P$,分别表示树的结点个数、操作个数、根节点序号和取模数(**即所有的输出结果均对此取模**)。
接下来一行包含 $N$ 个非负整数,分别依次表示各个节点上初始的数值。
接下来 $N-1$ 行每行包含两个整数 $x,y$,表示点 $x$ 和点 $y$ 之间连有一条边(保证无环且连通)。
接下来 $M$ 行每行包含若干个正整数,每行表示一个操作。
## 输出格式
输出包含若干行,分别依次表示每个操作 $2$ 或操作 $4$ 所得的结果(**对 $P$ 取模**)。
## 样例 #1
### 样例输入 #1
```
5 5 2 24
7 3 7 8 0
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
```
### 样例输出 #1
```
2
21
```
#include<bits/stdc++.h>
#include<iostream>
#include<iomanip>
#include<string>
#include<math.h>
#include<cmath>
using namespace std;
void exgcd(int a,int b,int&x,int&y){
if(b==0)
{
x=1;
y=0;
return;
}
exgcd(b,a%b,y,x);
y-=a/b*x;
}
int n,m,r,p;
int to[200007],nxt[200007],lst[200007],e;
void adde(int x,int y){
to[++e]=y;
nxt[e]=lst[x];
lst[x]=e;
}
int siz[100007],fa[100007],son[100007],dep[100007];
int top[100007],id[100007],rev[100007],cnt;
int w[100007],w2[100007];
void dfs1(int u,int f,int deep){
fa[u]=f;
siz[u]=1;
dep[u]=deep;
int mxsiz=0;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==f)
{
continue;
}
dfs1(v,u,deep+1);
siz[u]+=siz[v];
if(mxsiz<siz[v])
{
mxsiz=siz[v];
son[u]=v;
}
}
}
void dfs2(int u,int ntop){
id[u]=++cnt;
w2[cnt]=w[u];
top[u]=ntop;
if(!son[u])
{
return;
}
dfs2(son[u],ntop);
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa[u]||v==son[u])
{
continue;
}
dfs2(v,v);
}
}
struct node{
long long s,lz;
}a[400007];
void push_d(int pos,int l,int r){
int k=a[pos].lz,m=(l+r)>>1;
a[pos<<1].s=(a[pos<<1].s+k*(m-l+1))%p;
a[pos<<1].lz=(a[pos<<1].lz+k);
a[pos<<1|1].s=(a[pos<<1|1].s+k*(r-m))%p;
a[pos<<1|1].lz=(a[pos<<1|1].lz+k);
a[pos].lz=0;
}
void bd(int pos,int l,int r){
if(l==r)
{
a[pos].s=w2[l];
a[pos].s%=p;
return;
}
int m=(l+r)>>1;
bd(pos<<1,l,m);
bd(pos<<1|1,m+1,r);
a[pos].s=(a[pos<<1].s+a[pos<<1|1].s)%p;
}
long long qx(int pos,int l,int r,int ql,int qr){
//cout<<a[pos].s<<" "<<a[pos].lz<<" ";
if(ql<=l&&r<=qr)
{
//cout<<pos;
return a[pos].s;
}
long long sum=0;
int m=(l+r)>>1;
push_d(pos,l,r);
//cout<<a[pos].lz;
if(m>=ql)
{
sum+=qx(pos<<1,l,m,ql,qr);
}
if(m<qr)
{
sum+=qx(pos<<1|1,m+1,r,ql,qr);
}
sum%=p;
return sum;
}
void up_d(int pos,int l,int r,int ql,int qr,int add){
add%=p;
if(ql<=l&&r<=qr)
{
a[pos].s+=add*(r-l+1);
a[pos].lz+=add;
//cout<<pos<<" "<<l<<" "<<r<<" ";
return;
}
int m=(l+r)>>1;
push_d(pos,l,r);
if(m>=ql)
{
up_d(pos<<1,l,m,ql,qr,add);
}
if(m<qr)
{
up_d(pos<<1|1,m+1,r,ql,qr,add);
}
a[pos].s=(a[pos<<1].s+a[pos<<1|1].s)%p;
}
void lj_x(int x,int y,int k){
k%=p;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])
{
swap(x,y);
}
up_d(1,1,n,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])
{
swap(x,y);
}
up_d(1,1,n,id[x],id[y],k);
}
long long lj_s(int x,int y){
long long ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])
{
swap(x,y);
}
ans+=qx(1,1,n,id[top[x]],id[x]);
//cout<<ans<<endl;
ans%=p;
x=fa[top[x]];
}
if(dep[x]>dep[y])
{
swap(x,y);
}
//cout<<id[x]<<" "<<id[y]<<endl;
ans+=qx(1,1,n,id[x],id[y]);
return ans%p;
}
int main(){
cin>>n>>m>>r>>p;
for(int i=1;i<=n;i++){
cin>>w[i];
}
for(int i=1;i<n;i++){
int p1,p2;
cin>>p1>>p2;
adde(p1,p2);
adde(p2,p1);
}
dfs1(r,0,1);
dfs2(r,r);
bd(1,1,n);
for(int i=1;i<=m;i++){
int type,p1,p2,p3;
cin>>type;
cin>>p1;
if(type==1)
{
cin>>p2>>p3;
lj_x(p1,p2,p3);
}
if(type==2)
{
cin>>p2;
cout<<lj_s(p1,p2)<<endl;
}
if(type==3)
{
cin>>p2;
up_d(1,1,n,id[p1],id[p1]+siz[p1]-1,p2);
}
if(type==4)
{
cout<<qx(1,1,n,id[p1],id[p1]+siz[p1]-1)<<endl;
//cout<<id[p1];
//qx(1,1,n,id[p1],id[p1]+siz[p1]-1);
}
}
return 0;
}
114.7ms,11.36MB