传送门
lct经典题。
貌似就是ahoi维护序列放到了lct上。
就直接像线段树那样维护就行了。
不过这个时候区间add时不再是:
s
u
m
+
=
(
r
−
l
+
1
)
∗
v
sum+=(r-l+1)*v
sum+=(r−l+1)∗v了。
而应该是:
s
u
m
+
=
s
i
z
e
∗
v
sum+=size*v
sum+=size∗v,因为动态树是没有固定区间的。
然后注意标记的下放以及处处取模什么的应该就能AC了吧。
代码:
#include<bits/stdc++.h>
#define N 100005
#define mod 51061
using namespace std;
inline int read(){
int ans=0;
char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
return ans;
}
int rev[N],fa[N],son[N][2],siz[N],stk[N],top,n,m;
unsigned int val[N],add[N],mul[N],sum[N];
inline int which(int x){return x==son[fa[x]][1];}
inline bool isroot(int x){return !fa[x]||(x!=son[fa[x]][0]&&x!=son[fa[x]][1]);}
inline void pushup(int p){sum[p]=(sum[son[p][0]]+sum[son[p][1]]+val[p])%mod,siz[p]=(siz[son[p][0]]+siz[son[p][1]]+1)%mod;}
inline void pushnow(int p,int v1,int v2){
if(!p)return;
sum[p]=(sum[p]*v1%mod+siz[p]*v2%mod)%mod;
val[p]=(val[p]*v1%mod+v2)%mod;
add[p]=(add[p]*v1%mod+v2)%mod;
mul[p]=mul[p]*v1%mod;
}
inline void pushdown(int p){
if(rev[p]){
swap(son[p][0],son[p][1]),rev[p]^=1;
if(son[p][0])rev[son[p][0]]^=1;
if(son[p][1])rev[son[p][1]]^=1;
}
if(mul[p]!=1||add[p])pushnow(son[p][0],mul[p],add[p]),pushnow(son[p][1],mul[p],add[p]),mul[p]=1,add[p]=0;
}
inline void rotate(int x){
int y=fa[x],z=fa[y],t=which(x);
if(z&&!isroot(y))son[z][which(y)]=x;
fa[y]=x,fa[x]=z,son[y][t]=son[x][t^1],son[x][t^1]=y;
if(son[y][t])fa[son[y][t]]=y;
pushup(y),pushup(x);
}
inline void splay(int x){
stk[top=1]=x;
for(int i=x;!isroot(i);i=fa[i])stk[++top]=fa[i];
while(top)pushdown(stk[top--]);
while(!isroot(x)){if(!isroot(fa[x]))rotate(which(fa[x])==which(x)?fa[x]:x);rotate(x);}
}
inline void access(int x){for(int y=0;x;x=fa[y=x])splay(x),son[x][1]=y,pushup(x);}
inline void makeroot(int x){access(x),splay(x),rev[x]^=1,pushdown(x);}
inline void link(int x,int y){makeroot(x),fa[x]=y;}
inline void cut(int x,int y){makeroot(x),access(y),splay(y),son[y][0]=fa[x]=0,pushdown(x);}
inline void split(int x,int y){makeroot(y),access(x),splay(x);}
int main(){
n=read(),m=read();
for(int i=1;i<=n;++i)siz[i]=sum[i]=val[i]=mul[i]=1;
for(int i=1,u,v;i<n;++i)u=read(),v=read(),link(u,v);
while(m--){
char s[5];
int w,x,y;
scanf("%s",s),x=read(),y=read();
if(s[0]=='*')w=read(),split(x,y),pushnow(x,w,0);
if(s[0]=='+')w=read(),split(x,y),pushnow(x,1,w);
if(s[0]=='-')cut(x,y),x=read(),y=read(),link(x,y);
if(s[0]=='/')split(x,y),printf("%d\n",sum[x]);
}
return 0;
}