题目链接:2631:tree
裸LCT模板题,注意清理标记的顺序:先乘法后加法,清理乘法标记的时候子节点的加法标记什么的也要乘以这个数
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define ui unsigned int
using namespace std;
const int maxn=110010;
const int mod=51061;
int n,q,sta[maxn];
struct Nodes{
int fa,c[2],rev,size;
ui mul,add,val,sum;
};
struct splay_tree{
Nodes t[maxn];
void init(int x){
t[x].val=1;t[x].mul=1;t[x].add=0;
t[x].sum=1;t[x].size=1;
}
void update(int x,ui m,ui a){
if (m==1&&a==0) return;
t[x].val=(t[x].val*m+a)%mod;
t[x].sum=(t[x].sum*m+a*t[x].size)%mod;
t[x].add=(t[x].add*m+a)%mod;
t[x].mul=(t[x].mul*m)%mod;
}
void push_up(int x){
int lson=t[x].c[0],rson=t[x].c[1];
t[x].sum=(t[lson].sum+t[rson].sum+t[x].val)%mod;
t[x].size=t[lson].size+t[rson].size+1;
}
void push_down(int x){
if (t[x].rev){
swap(t[x].c[0],t[x].c[1]);
t[t[x].c[0]].rev^=1;
t[t[x].c[1]].rev^=1; t[x].rev=0;
}
ui m=t[x].mul,a=t[x].add;
t[x].mul=1; t[x].add=0;
if (t[x].c[0]) update(t[x].c[0],m,a);
if (t[x].c[1]) update(t[x].c[1],m,a);
}
void rotate(int p,int x){
int mark= p==t[x].c[1];
int y=t[p].c[mark^1],z=t[x].fa;
if (x==t[z].c[0]) t[z].c[0]=p;
else if (x==t[z].c[1]) t[z].c[1]=p;
if (y) t[y].fa=x; t[x].fa=p; t[x].c[mark]=y;
t[x].fa=p; t[p].c[mark^1]=x; t[p].fa=z;
push_up(x);
}
bool isroot(int x){
return t[t[x].fa].c[0]!=x&&t[t[x].fa].c[1]!=x;
}
void pre(int x){
int top=0;
while (!isroot(x)) sta[++top]=x,x=t[x].fa;
sta[++top]=x;
while (top) push_down(sta[top]),top--;
}
void splay(int p){
pre(p);
while (!isroot(p)){
int x=t[p].fa,y=t[x].fa;
if (isroot(x)) rotate(p,x);
else if (p==t[x].c[0]^x==t[y].c[0]) rotate(p,x),rotate(p,y);
else rotate(x,y),rotate(p,x);
}push_up(p);
}
void access(int x){
for (int v=0;x;x=t[x].fa){
splay(x); t[x].c[1]=v;
push_up(x); t[v].fa=x; v=x;
}
}
void rever(int x){
access(x); splay(x); t[x].rev^=1;
}
void cut(int x,int y){
rever(x); access(y); splay(y);
t[x].fa=t[y].c[0]=0;
}
void link(int x,int y){
rever(x); t[x].fa=y; access(x);
}
void getans(int x,int y){
rever(x); access(y); splay(y);
printf("%d\n",(t[y].sum%mod+mod)%mod);
}
void modifyadd(int x,int y,ui z){
rever(x); access(y); splay(y);
update(y,1,z);
}
void modifymul(int x,int y,ui z){
rever(x); access(y); splay(y);
update(y,z,0);
}
}lct;
int main(){
scanf("%d%d",&n,&q);
for (int i=1;i<=n;++i) lct.init(i);
for (int i=1;i<n;++i){
int x,y; scanf("%d%d",&x,&y);
lct.link(x,y);
}
for (int i=1;i<=q;++i){
char s[5]; scanf("%s",s);
if (s[0]=='+'){
int x,y; ui z;
scanf("%d%d%d",&x,&y,&z);
lct.modifyadd(x,y,z);
}else if (s[0]=='-'){
int x,y,u,v;
scanf("%d%d%d%d",&x,&y,&u,&v);
lct.cut(x,y); lct.link(u,v);
}else if (s[0]=='*'){
int x,y; ui z;
scanf("%d%d%d",&x,&y,&z);
lct.modifymul(x,y,z);
}else if (s[0]=='/'){
int x,y;
scanf("%d%d",&x,&y);
lct.getans(x,y);
}
}
}