emmm...
虽说树剖长度较长(也还好)
但理解起来还不算太难?
根据此题写下学习笔记
树链剖分前置知识:基本知识($DFS$,链式前向星),$LCA$(至少概念),线段树
首先引入基本概念:
1、重儿子:
指每个节点(非叶子结点)的儿子中,以其儿子为根的子树中节点个数最多的儿子
易知每个节点(非叶子结点)的重儿子有且仅有一个
2、轻儿子:
指每个节点(非叶子结点)的儿子中,除重儿子之外的所有儿子
3、重边:
一个父亲连接其重儿子的边称为重边
4、轻边:
除重边之外的所有边
#include<cstdio>
#define maxn 200002
inline int read(){
int r=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')r=(r<<1)+(r<<3)+(c^48),c=getchar();
return r*f;
}
struct TR{
int l,r,sum,add;
}tr[4*maxn];
struct E{
int v,nxt;
}e[2*maxn];
int s_e,head[maxn],fa[maxn],size[maxn],son[maxn],dep[maxn],seg[maxn],rev[maxn],top[maxn],max_seg[maxn],val[maxn],s_tr,n,m,root,mod;
inline void a_e(int u,int v){
e[++s_e]=(E){v,head[u]};
head[u]=s_e;
}
inline void add(int u,int v){
a_e(u,v);
a_e(v,u);
}
void build(int p,int l,int r){
tr[p].l=l,tr[p].r=r;
if(l==r){
tr[p].sum=val[rev[l]]%mod;
return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);
build((p<<1)+1,mid+1,r);
tr[p].sum=(tr[(p<<1)].sum+tr[(p<<1)+1].sum)%mod;
}
inline void spread(int p){
if(tr[p].add){
tr[(p<<1)].sum=(long long)(tr[(p<<1)].sum+tr[p].add*(tr[(p<<1)].r-tr[(p<<1)].l+1))%mod;
tr[(p<<1)+1].sum=(long long)(tr[(p<<1)+1].sum+tr[p].add*(tr[(p<<1)+1].r-tr[(p<<1)+1].l+1))%mod;
tr[(p<<1)].add+=tr[p].add;
tr[(p<<1)+1].add+=tr[p].add;
}
tr[p].add=0;
}
void change(int p,int l,int r,int x){
if(l<=tr[p].l&&tr[p].r<=r){
tr[p].sum=(long long)(tr[p].sum+x*(tr[p].r-tr[p].l+1))%mod;
tr[p].add+=x;
return;
}
spread(p);
int mid=(tr[p].l+tr[p].r)>>1;
if(l<=mid)change(p<<1,l,r,x);
if(r>mid)change((p<<1)+1,l,r,x);
tr[p].sum=(tr[(p<<1)].sum+tr[(p<<1)+1].sum)%mod;
}
void query(int p,int l,int r,int &sum){
if(l<=tr[p].l&&tr[p].r<=r){
sum=(sum+tr[p].sum)%mod;
return;
}
spread(p);
int mid=(tr[p].l+tr[p].r)>>1;
if(l<=mid)query(p<<1,l,r,sum);
if(r>mid)query((p<<1)+1,l,r,sum);
}
inline int max(int a,int b){
return a>b?a:b;
}
void dfs_1(int u){
size[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa[u])continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs_1(v);
size[u]+=size[v];
if(size[v]>size[son[u]])son[u]=v;
}
}
void dfs_2(int u,int t){
top[u]=t;
seg[u]=++s_tr;
rev[s_tr]=u;
max_seg[u]=s_tr;
if(son[u])dfs_2(son[u],t),max_seg[u]=max_seg[son[u]];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(top[v])continue;
dfs_2(v,v);
max_seg[u]=max_seg[v];
}
}
inline void swap(int &a,int &b){
int c=a;
a=b;
b=c;
}
inline int ask(int u,int v){
int sum=0;
int fu=top[u],fv=top[v];
while(fu!=fv){
if(dep[fu]<dep[fv])swap(u,v),swap(fu,fv);
query(1,seg[fu],seg[u],sum);
u=fa[fu],fu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
query(1,seg[u],seg[v],sum);
return sum;
}
inline void update(int u,int v,int x){
int fu=top[u],fv=top[v];
while(fu!=fv){
if(dep[fu]<dep[fv])swap(u,v),swap(fu,fv);
change(1,seg[fu],seg[u],x);
u=fa[fu],fu=top[u];
}
if(dep[u]>dep[v])swap(u,v);
change(1,seg[u],seg[v],x);
}
int main(){
n=read(),m=read(),root=read(),mod=read();
for(int i=1;i<=n;i++)val[i]=read();
for(int i=1;i<n;i++){
int u=read(),v=read();
add(u,v);
}
dfs_1(root);
dfs_2(root,root);
build(1,1,s_tr);
for(int i=1;i<=m;i++){
int op=read();
if(op==1){
int x=read(),y=read(),z=read();
update(x,y,z);
}
else if(op==2){
int x=read(),y=read();
printf("%d\n",ask(x,y));
}
else if(op==3){
int x=read(),z=read();
change(1,seg[x],max_seg[x],z);
}
else{
int x=read(),sum=0;
query(1,seg[x],max_seg[x],sum);
printf("%d\n",sum);
}
}
return 0;
}