板子题链接:https://www.luogu.org/problemnew/show/P3384
个人理解:树链剖分最重要的作用是实现了重链上的节点序号都是连续的,这样对于更改和查询一条树链就变成了区间操作,可以用线段树实现。
#include<bits/stdc++.h>
using namespace std;
// Child indices of a segment-tree node. Both macros expand textually and rely
// on a variable named rt being in scope where they are used.
#define lson rt<<1
#define rson rt<<1|1
// Capacity constant used to size the global arrays of this file.
const int maxn=2e5+5;
// Read by main(): vertex count, number of operations, root vertex, modulus.
int n,m,r,p;
// val[u]: initial value read for vertex u.
// nxb[u]: position assigned to vertex u by dfs2() (its index in the segment tree).
int val[maxn],nxb[maxn];
// One directed edge u -> v kept in the global pool edge[].
// nxt is the pool index of the next edge in the same adjacency list
// (the lists are chained through head[] by addedge()).
struct node
{
    int u,v,nxt;
    // Leaves the fields unset; used for the elements of the global pool.
    node(){}
    // Builds the edge from -> to whose list successor is next.
    node(int from,int to,int next)
    {
        u=from;
        v=to;
        nxt=next;
    }
}edge[maxn<<2];
// Shared running counter reused by several phases (edge numbering in
// addedge(), position numbering in dfs2()); main() resets it to 0 between phases.
int cnt;
// head[u]: pool index of the most recently added edge leaving u;
// main() fills the array with -1, which the traversal loops treat as "end of list".
int head[maxn<<2];
int dep[maxn],fa[maxn],sonsize[maxn],hson[maxn],topfa[maxn];//depth, parent, subtree size (the vertex itself included), heavy child, top vertex of the chain
// newval[i]: value of the vertex that dfs2() placed at position i.
int newval[maxn];
void addedge(int u,int v)//链式前向星
{
edge[++cnt]=node(u,v,head[u]);
head[u]=cnt;
}
// tree[]: segment-tree node sums; lazy[]: pending range-add tags not yet pushed to the children.
int tree[maxn<<2],lazy[maxn<<2];
// Builds the segment tree node rt covering the inclusive position range [l, r].
//   rt   : index of the node being built (the root is 1)
//   l, r : inclusive range of positions covered by rt
// A leaf at position l stores newval[l] mod p; an inner node stores the sum of
// its two children mod p.
void bulid(int rt,int l,int r)
{
    if(l==r)
    {
        // Index by the leaf's own position instead of a shared running
        // counter, so the result does not depend on cnt being 0 at call time.
        tree[rt]=newval[l]%p;
        return;
    }
    int mid=(l+r)>>1;
    bulid(lson,l,mid);
    bulid(rson,mid+1,r);
    // Add in 64 bits: two residues can exceed INT_MAX together when p > 2^30.
    tree[rt]=static_cast<int>((static_cast<long long>(tree[lson])+tree[rson])%p);
}
// Pushes the pending add tag of node rt down to its two children.
//   rt   : node whose tag is propagated
//   l, r : inclusive range covered by rt
// Each child receives the tag and its sum grows by (child length * tag); the
// tag of rt is then cleared. Does nothing when rt has no pending tag.
void pushdown(int rt,int l,int r)
{
    if(lazy[rt]==0)return;
    // All arithmetic is done in 64 bits: length * tag easily exceeds the int
    // range, and tags are kept reduced mod p so they cannot grow without bound.
    const long long add=lazy[rt];
    const int mid=(l+r)>>1;
    lazy[lson]=static_cast<int>((lazy[lson]+add)%p);
    lazy[rson]=static_cast<int>((lazy[rson]+add)%p);
    tree[lson]=static_cast<int>((tree[lson]+(mid-l+1)*add)%p);
    tree[rson]=static_cast<int>((tree[rson]+(r-mid)*add)%p);
    lazy[rt]=0;
}
// Adds v to every position in [x, y] (range add, sums kept mod p).
//   rt   : current node, covering the inclusive range [l, r]
//   x, y : inclusive update range; callers pass x <= y inside [l, r]
//   v    : amount added to each position
void updata(int rt,int l,int r,int x,int y,int v)
{
    if(l>=x&&r<=y)
    {
        // Node fully covered: record the tag and update the sum. Use 64-bit
        // arithmetic because (r-l+1)*v does not fit in int for large ranges
        // and values, and keep the tag reduced mod p so it stays bounded.
        const long long add=v;
        lazy[rt]=static_cast<int>((lazy[rt]+add)%p);
        tree[rt]=static_cast<int>((tree[rt]+(r-l+1)*add)%p);
        return;
    }
    int mid=(l+r)>>1;
    // Children must see older pending additions before being modified.
    pushdown(rt,l,r);
    if(y<=mid)updata(lson,l,mid,x,y,v);
    else if(x>=mid+1)updata(rson,mid+1,r,x,y,v);
    else updata(lson,l,mid,x,mid,v),updata(rson,mid+1,r,mid+1,y,v);
    // 64-bit add: two residues can exceed INT_MAX together when p > 2^30.
    tree[rt]=static_cast<int>((static_cast<long long>(tree[lson])+tree[rson])%p);
}
// Returns the sum of positions [x, y] reduced mod p.
//   rt   : current node, covering the inclusive range [l, r]
//   x, y : inclusive query range; callers pass x <= y inside [l, r]
int query(int rt,int l,int r,int x,int y)
{
    if(l>=x&&r<=y)
    {
        return tree[rt];
    }
    int mid=(l+r)>>1;
    // Pending additions must reach the children before they are read.
    pushdown(rt,l,r);
    if(y<=mid)return query(lson,l,mid,x,y)%p;
    if(x>=mid+1)return query(rson,mid+1,r,x,y)%p;
    // The range straddles mid: combine both halves in 64 bits, since two
    // residues can exceed INT_MAX together when p > 2^30.
    const long long left=query(lson,l,mid,x,mid);
    const long long right=query(rson,mid+1,r,mid+1,y);
    return static_cast<int>((left+right)%p);
}
// First traversal: for the subtree rooted at u (entered from parent f at depth
// deep) fills dep[], fa[], sonsize[] and hson[].
// hson[u] ends up as the child with the largest subtree (the first such child
// in adjacency order on ties), or 0 when u has no child.
void dfs1(int u,int f,int deep)
{
    dep[u]=deep;
    fa[u]=f;
    sonsize[u]=1;
    hson[u]=0;
    int heaviest=-1;   // size of the largest child subtree seen so far
    for(int e=head[u];e!=-1;e=edge[e].nxt)
    {
        const int child=edge[e].v;
        if(child!=f)   // the edge back to the parent is not a child
        {
            dfs1(child,u,deep+1);
            sonsize[u]+=sonsize[child];
            if(heaviest<sonsize[child])
            {
                heaviest=sonsize[child];
                hson[u]=child;
            }
        }
    }
}
// Second traversal: gives u the next free position, copies its value to that
// position and records topf as the top vertex of u's chain.
// The heavy child is visited first and stays on the same chain, so positions
// along a heavy chain are consecutive; every other child starts a chain of its own.
void dfs2(int u,int topf)
{
    const int pos=++cnt;
    nxb[u]=pos;
    newval[pos]=val[u];
    topfa[u]=topf;
    const int heavy=hson[u];
    if(heavy==0)return;          // no child at all
    dfs2(heavy,topf);
    for(int e=head[u];e!=-1;e=edge[e].nxt)
    {
        const int child=edge[e].v;
        if(child!=fa[u]&&child!=heavy)dfs2(child,child);
    }
}
// Sum over the subtree of u: the subtree occupies the sonsize[u] consecutive
// positions starting at nxb[u].
int sum_subtree(int u)
{
    const int first=nxb[u];
    const int last=first+sonsize[u]-1;
    return query(1,1,n,first,last);
}
// Adds k to every vertex in the subtree of u, i.e. to the sonsize[u]
// consecutive positions starting at nxb[u].
void up_subtree(int u,int k)
{
    const int first=nxb[u];
    const int last=first+sonsize[u]-1;
    updata(1,1,n,first,last,k);
}
// Returns the sum of the values on the tree path between u and v, mod p.
// While the two vertices are on different chains, the one whose chain top is
// deeper contributes the segment [top of chain, itself] and jumps to the
// parent of that top; once both are on one chain the remaining segment is added.
int sum_lian(int u,int v)
{
    // 64-bit accumulator: adding two residues in int overflows when p > 2^30.
    long long ret=0;
    while(topfa[u]!=topfa[v])
    {
        if(dep[topfa[u]]<dep[topfa[v]])std::swap(u,v);
        ret=(ret+query(1,1,n,nxb[topfa[u]],nxb[u]))%p;
        u=fa[topfa[u]];
    }
    if(dep[u]>dep[v])std::swap(u,v);// on one chain the shallower vertex has the smaller position
    ret=(ret+query(1,1,n,nxb[u],nxb[v]))%p;
    return static_cast<int>(ret);
}
void up_lian(int u,int v,int k)
{
while(topfa[u]!=topfa[v])
{
if(dep[topfa[u]]<dep[topfa[v]])swap(u,v);
updata(1,1,n,nxb[topfa[u]],nxb[u],k);
u=fa[topfa[u]];
}
if(dep[u]>dep[v])swap(u,v);
updata(1,1,n,nxb[u],nxb[v],k);
}
// Reads the tree and the operations from stdin and answers them on stdout.
// Input: n m r p, then n vertex values, then n-1 edges, then m operations:
//   1 x y z : add z to every vertex on the path x..y
//   2 x y   : print the path sum x..y mod p
//   3 x y   : add y to every vertex in the subtree of x
//   other x : print the subtree sum of x mod p
int main()
{
    scanf("%d%d%d%d",&n,&m,&r,&p);
    for(int i=1;i<=n;++i)
    {
        scanf("%d",&val[i]);
    }

    int x,y,z;

    // Adjacency lists: -1 marks the end of each list.
    cnt=0;
    memset(head,-1,sizeof(head));
    for(int i=1;i<n;++i)
    {
        scanf("%d%d",&x,&y);
        addedge(x,y);
        addedge(y,x);
    }

    // Decompose the tree rooted at r, then build the segment tree over the
    // positions handed out by dfs2(). cnt is reset before each phase.
    cnt=0;
    dfs1(r,0,1);
    dfs2(r,r);
    cnt=0;
    bulid(1,1,n);

    int op;
    while(m--)
    {
        scanf("%d",&op);
        switch(op)
        {
        case 1:
            scanf("%d%d%d",&x,&y,&z);
            up_lian(x,y,z);
            break;
        case 2:
            scanf("%d%d",&x,&y);
            printf("%d\n",sum_lian(x,y)%p);
            break;
        case 3:
            scanf("%d%d",&x,&y);
            up_subtree(x,y);
            break;
        default:
            scanf("%d",&x);
            printf("%d\n",sum_subtree(x)%p);
            break;
        }
    }
    return 0;
}