// Heavy-light decomposition template for Luogu problem P3384 (洛谷树链剖分模板题).
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
const int INF = 1e9 + 7;
const int MAX_N = 100010;
// Modulus read from input; values kept in tree[]/lazy[] are reduced by it.
long long mode;
// Adjacency lists of the input tree (both directions of each edge are stored).
std::vector<int>edge[MAX_N];
// Segment tree over DFS positions: tree[] holds range sums, lazy[] pending additions.
// long long instead of int: the modulus may be as large as 2^31-1, so adding two
// stored values (or multiplying an addend by a segment length) overflows a 32-bit int.
long long tree[MAX_N*4],lazy[MAX_N*4];
// dfn[v]: position of v in the DFS order; rev_dfn[p]: vertex at position p.
// Start[v]..End[v]: positions covered by v's subtree; tot: last position handed out.
// dep[v]: depth of v; par[v]: parent of v.
int dfn[MAX_N],rev_dfn[MAX_N],Start[MAX_N],End[MAX_N],tot,dep[MAX_N],par[MAX_N];
// major[v]: the child of v with the most vertices in its subtree (v itself if v has no child);
// super[v]: the top vertex of the chain that v belongs to.
int major[MAX_N],super[MAX_N];
// Initial vertex weights, indexed by vertex.
int val[MAX_N];
// First DFS: compute depth, parent and heavy child (the child with the largest subtree).
// First DFS from v, whose parent is `far`.
// Sets par[v] and dep[v] (= dep[far] + 1). Stores in major[v] the child whose
// subtree has the most vertices, or v itself when v has no children.
// Returns the number of vertices in v's subtree.
int dfs1(int v,int far)
{
    par[v] = far;
    dep[v] = dep[far] + 1;

    int subtreeSize = 1;
    int heavyChild = v;
    int heavySize = 0;
    for( const int u : edge[v] )
    {
        if( u == far ) continue;            // don't walk back to the parent
        const int childSize = dfs1(u,v);
        subtreeSize += childSize;
        // Strict '>' keeps the earliest child (in edge order) on ties.
        if( childSize > heavySize )
        {
            heavySize = childSize;
            heavyChild = u;
        }
    }
    major[v] = heavyChild;
    return subtreeSize;
}
// Second DFS: compute the DFS order (heavy child first) and the top vertex of each chain.
// Second DFS from v (parent `far`), with `sup` being the top vertex of v's chain.
// Hands out consecutive positions via ++tot and visits the heavy child first, so
// each chain occupies a contiguous run of positions. Sets dfn[v], Start[v],
// rev_dfn and super[v]. After all descendants are numbered, End[v] holds the
// last position inside v's subtree.
void dfs2(int v,int far,int sup)
{
    ++tot;
    dfn[v] = tot;
    Start[v] = tot;
    rev_dfn[tot] = v;
    super[v] = sup;

    const int heavy = major[v];
    // major[v] == v means v has no child; otherwise the heavy child continues v's chain.
    if( heavy != v ) dfs2(heavy,v,sup);

    for( const int u : edge[v] )
    {
        if( u == far || u == heavy ) continue;
        dfs2(u,v,u);                        // every light child starts its own chain
    }
    End[v] = tot;
}
// Applies the pending addition lazy[node] to tree[node], which covers positions
// left..right. For internal nodes it also hands the same addition on to both
// children's lazy[] entries. In this file's lazy scheme a node's tree[] value is
// only current once push_down has been called on it.
// All stored results are kept reduced mod `mode`.
inline void push_down(int node,int left,int right)
{
    if( left>right || lazy[node]==0 )return ;
    const long long add = lazy[node] % mode;
    // Multiply in long long: add * segment length can exceed the int range.
    tree[node] = ( tree[node] % mode + add * (right-left+1) % mode ) % mode;
    // Only internal nodes have children. For a leaf, 2*node+1 can exceed the
    // MAX_N*4 bound of lazy[] when n is close to MAX_N, so writing there would
    // go out of bounds.
    if( left < right ){
        lazy[2*node]   = ( lazy[2*node]   + add ) % mode;
        lazy[2*node+1] = ( lazy[2*node+1] + add ) % mode;
    }
    lazy[node] = 0;
}
// Builds the segment tree node `node` covering DFS positions left..right.
// The leaf at position p stores the weight of vertex rev_dfn[p]; each internal
// node stores the sum of its two children. Every stored value is reduced mod `mode`.
void buildTree(int node,int left,int right)
{
    if( left == right ){
        // Reduce on entry. A leaf may be returned directly by a query, and an
        // unreduced value would also make the parent's sum larger than needed.
        tree[node] = val[ rev_dfn[left] ] % mode;
        return ;
    }
    int mid = (left+right)/2;
    buildTree(2*node,left,mid);
    buildTree(2*node+1,mid+1,right);
    // Add in long long: two values just below 2^31-1 overflow a 32-bit int.
    tree[node] = ( 1LL*tree[2*node] + tree[2*node+1] ) % mode;
    return ;
}
// Returns the sum, mod `mode`, of the positions in x..y that fall inside this
// node's range left..right. Top-level callers use node=1, left=1, right=n.
// Pending lazy additions are pushed down along the visited path.
long long query(int x,int y,int node,int left,int right)
{
    // Settle this node's pending addition first so tree[node] is current.
    push_down(node,left,right);
    if( x<=left && right<=y )
    {
        // Reduce here so the result is in [0, mode) even if the stored value was
        // never reduced (push_down returns early when there is nothing pending).
        return tree[node] % mode;
    }
    long long ret = 0;
    int mid = (left+right)/2;
    if( mid>=x )ret += query(x,y,2*node,left,mid);
    if( mid<y )ret += query(x,y,2*node+1,mid+1,right);
    return ret % mode;
}
// Adds `val` to every position in x..y that falls inside this node's range
// left..right, keeping tree[] and lazy[] reduced mod `mode`.
// Top-level callers use node=1, left=1, right=n.
// Note: inside this function the parameter `val` shadows the global weight array
// of the same name.
void modify(int x,int y,int node,int left,int right,int val)
{
    // Normalise the increment into [0, mode) so lazy[] cannot grow without bound
    // (and negative increments stay well-defined under %).
    const long long add = ( val % mode + mode ) % mode;
    if( x<=left && right<=y )
    {
        lazy[node] = ( lazy[node] + add ) % mode;
        push_down(node,left,right);
        return ;
    }
    push_down(node,left,right);
    int mid = (left+right)/2;
    if( mid>=x )modify(x,y,2*node,left,mid,val);
    if( mid<y )modify(x,y,2*node+1,mid+1,right,val);
    // Settle both children (including the one not recursed into, which may have
    // just received lazy from the push above) before combining their sums.
    push_down(2*node,left,mid); push_down(2*node+1,mid+1,right);
    // Add in long long to avoid 32-bit overflow before the reduction.
    tree[node] = ( 1LL*tree[2*node] + tree[2*node+1] ) % mode;
}
// Returns the sum, mod `mode`, of the weights of all vertices on the tree path
// between x and y, both included. n is the segment-tree range size (1..n).
long long query_path(int x,int y,int n)
{
    long long total = 0;
    while( super[x] != super[y] )
    {
        // Move the endpoint whose chain top is deeper; on equal depth x moves.
        if( dep[super[y]] > dep[super[x]] ) swap(x,y);
        const int top = super[x];
        // Positions dfn[top]..dfn[x] are exactly the chain segment from top down to x.
        total = ( total + query( dfn[top], dfn[x], 1,1,n ) ) % mode;
        x = par[top];
    }
    // Both endpoints are now on one chain; the shallower one has the smaller position.
    if( dep[x] > dep[y] ) swap(x,y);
    total += query( dfn[x], dfn[y], 1,1,n );
    return total % mode;
}
// Adds z to the weight of every vertex on the tree path between x and y, both
// included. n is the segment-tree range size (1..n).
void modify_path(int x,int y,int z,int n)
{
    // Each step updates the chain segment of the endpoint with the deeper chain
    // top (x on ties), then lifts that endpoint above its chain top.
    for( ; super[x] != super[y]; x = par[super[x]] )
    {
        if( dep[super[y]] > dep[super[x]] ) swap(x,y);
        modify( dfn[super[x]], dfn[x], 1,1,n,z );
    }
    // Same chain: update the contiguous positions from the shallower endpoint down.
    if( dep[y] < dep[x] ) swap(x,y);
    modify( dfn[x], dfn[y], 1,1,n,z );
}
// Reads n (vertex count), m (operation count), the root and the modulus, then
// n vertex weights and n-1 undirected edges. Builds the decomposition and the
// segment tree, then handles m operations:
//   1 x y z : add z to every vertex on the path x..y
//   2 x y   : print the weight sum over the path x..y
//   3 x z   : add z to every vertex in x's subtree
//   4 x     : print the weight sum over x's subtree
// An unknown opcode is read and otherwise ignored.
int main()
{
    int n,m,root;
    scanf("%d %d %d %lld",&n,&m,&root,&mode);
    for( int i=1; i<=n; ++i ) scanf("%d",&val[i]);

    int x,y,op,z;
    for( int edgesLeft = n-1; edgesLeft > 0; --edgesLeft )
    {
        scanf("%d %d",&x,&y);
        edge[x].push_back(y);
        edge[y].push_back(x);
    }

    dep[0] = 0;                 // vertex 0 acts as the root's virtual parent
    dfs1(root,0);
    tot = 0;
    dfs2(root,0,root);
    buildTree(1,1,n);

    for( int remaining = m; remaining > 0; --remaining )
    {
        scanf("%d",&op);
        switch( op )
        {
        case 1:
            scanf("%d %d %d",&x,&y,&z);
            modify_path(x,y,z,n);
            break;
        case 2:
            scanf("%d %d",&x,&y);
            printf("%lld\n",query_path(x,y,n));
            break;
        case 3:
            scanf("%d %d",&x,&z);
            modify(Start[x],End[x],1,1,n,z);
            break;
        case 4:
            scanf("%d",&x);
            printf("%lld\n",query(Start[x],End[x],1,1,n));
            break;
        default:
            break;
        }
    }
    return 0;
}