找不到比这更适合当模板题的题了。
#include<cstdio>
#include<iostream>
#include<cstring>
#define mod 51061
#define ll unsigned int
using namespace std;
const int maxn=200010;
ll mt[maxn],at[maxn],fa[maxn],sz[maxn],ch[maxn][2],s[maxn],sum[maxn],val[maxn],top;
bool rev[maxn];
int n,m,x,y,u,v,c;
void cal(int x,int m,int a)
{
if(!x) return;
val[x]=(val[x]*m+a)%mod;
sum[x]=(sum[x]*m+a*sz[x])%mod;
at[x]=(at[x]*m+a)%mod;
mt[x]=(mt[x]*m)%mod;
}
inline void update(int x)
{
int l=ch[x][0],r=ch[x][1];
sum[x]=(sum[l]+sum[r]+val[x])%mod;
sz[x]=(sz[l]+sz[r]+1)%mod;
//cout<<sz[x]<<endl;
}
bool isroot(int x)
{
return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;
}
inline void pushdown(int x)
{
if(rev[x])
{
rev[x]^=1,rev[ch[x][1]]^=1,rev[ch[x][0]]^=1;
swap(ch[x][1],ch[x][0]);
}
int m=mt[x],a=at[x];
mt[x]=1;at[x]=0;
if(m!=1||a!=0)
{
cal(ch[x][0],m,a);cal(ch[x][1],m,a);
}
}
void rotate(int x)
{
int y=fa[x],z=fa[y],l=ch[y][1]==x,r=l^1;
if(!isroot(y)) ch[z][ch[z][1]==y]=x;
fa[x]=z;fa[y]=x;fa[ch[x][r]]=y;
ch[y][l]=ch[x][r];ch[x][r]=y;
//cout<<x<<" "<<y<<endl;
update(y);update(x);
}
void splay(int x)
{
int top=0;s[++top]=x;
for(int i=x;!isroot(i);i=fa[i])
s[++top]=fa[i];
for(int i=top;i;i--) pushdown(s[i]);
while(!isroot(x))
{
int y=fa[x],z=fa[y];
if(!isroot(y))
{
if(ch[y][0]==x^ch[z][0]==y) rotate(x);
else rotate(y);
}
rotate(x);
}
}
void access(int x)
{
for(int t=0;x;t=x,x=fa[x])
{
//cout<<x<<endl;
splay(x);ch[x][1]=t;update(x);
}
}
void makeroot(int x)
{
access(x);splay(x);rev[x]^=1;
}
void link(int x,int y)
{
makeroot(x);fa[x]=y;
}
void cut(int x,int y)
{
makeroot(x);access(y);splay(y);ch[y][0]=fa[x]=0;
}
void split(int x,int y)
{
makeroot(y);access(x);splay(x);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
val[i]=sum[i]=mt[i]=sz[i]=1;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
link(x,y);
}
char op[5];
while(m--)
{
scanf("%s",op);
scanf("%d%d",&u,&v);
if(op[0]=='+')
{
scanf("%d",&c);
split(u,v);cal(u,1,c);
}
if(op[0]=='-')
{
cut(u,v);
scanf("%d%d",&u,&v);
link(u,v);
}
if(op[0]=='*')
{
scanf("%d",&c);
split(u,v);cal(u,c,0);
}
if(op[0]=='/')
{
split(u,v);
printf("%d\n",sum[u]);
}
}
return 0;
}