传送门~
LCT模板。
代码：
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<string>
#include<vector>
// Textually replaces every later `int` with `long long`. This matters for the
// modular arithmetic below: e.g. x*siz in node::add can exceed 2^31 when siz is
// large. main() must #undef it around its own signature (`long long main` is
// ill-formed) and then re-#define it.
#define int long long
using namespace std;
// Modulus applied to all vertex weights, path sums and lazy tags.
const int mod=51061;
// Link-cut tree node. Each preferred path is stored as a splay tree whose
// in-order sequence runs from the shallow end to the deep end of the path.
// `fa` is either the splay parent or, for a splay root, the path-parent
// pointer to the tree vertex above the path (see son()).
struct node{
// siz: number of nodes in this splay subtree
// val: this vertex's weight, kept reduced mod `mod`
// sum: sum of val over this splay subtree, mod `mod`
// add_mark / times_mark: lazy "add" / "multiply" already applied to this node
//   but not yet propagated to its children (identity values: 0 / 1)
int siz,val,sum,add_mark,times_mark;
// lazy "reverse subtree" already applied here (children swapped) but not yet
// propagated to the children
bool rev_mark;
node *ch[2],*fa;
// Returns 0 or 1 if this node is the left/right child of fa, or -1 if fa does
// not list it as a child, i.e. this node is the root of its splay tree.
int son(){
if(fa->ch[0]==this) return 0;
if(fa->ch[1]==this) return 1;
return -1;
}
node();
void pushdown();
void maintain();
void rev();
void add(int);
void times(int);
}*null=new node(),*root[100005];
// `null` is the shared sentinel used instead of nullptr for missing children and
// parents. It is the first node constructed, while the global pointer `null`
// is still zero-initialized, so the constructor gives it siz=val=sum=0 and
// nullptr links. root[i] is the node that represents tree vertex i.
// Default constructor.
// The first node ever built is the `null` sentinel: while it is being
// constructed the global pointer `null` is still zero-initialized, so the
// sentinel gets siz=val=sum=0 and nullptr links. Every later node is a real
// vertex with siz=1, initial weight 1, and links to the sentinel.
node:: node(){
siz=null ? 1 : 0;
val=sum=siz;
ch[0]=ch[1]=fa=null;
add_mark=0;
times_mark=1;
// Must be set explicitly: `new node()` with this user-provided constructor
// does not zero-fill the object, and pushdown() reads rev_mark.
rev_mark=false;
}
void node:: pushdown(){
if(this==null) return ;
if(times_mark!=1){
ch[0]->times(times_mark);
ch[1]->times(times_mark);
times_mark=1;
}
if(add_mark){
ch[0]->add(add_mark);
ch[1]->add(add_mark);
add_mark=0;
}
if(rev_mark){
ch[0]->rev();
ch[1]->rev();
rev_mark=false;
}
}
void node:: maintain(){
if(this==null) return ;
siz=ch[0]->siz+ch[1]->siz+1;
sum=(ch[0]->sum+ch[1]->sum+val)%mod;
}
// Add x to every vertex weight in this splay subtree: fix this node's val and
// the subtree sum (siz vertices each gain x) now, and record x in add_mark so
// pushdown() can hand it to the children later.
void node:: add(int x){
    if(this!=null){
        sum=(sum+siz*x)%mod;
        val=(val+x)%mod;
        add_mark=(add_mark+x)%mod;
    }
}
// Multiply every vertex weight in this splay subtree by x. The pending add tag
// is scaled too, because pushdown() applies it after the multiply.
void node:: times(int x){
    if(this==null) return ;
    val*=x;        val%=mod;
    sum*=x;        sum%=mod;
    add_mark*=x;   add_mark%=mod;
    times_mark*=x; times_mark%=mod;
}
// Reverse this splay subtree: swap the children immediately and toggle the
// mark so the reversal is pushed further down on the next pushdown().
void node:: rev(){
    if(this!=null){
        std::swap(ch[0],ch[1]);
        rev_mark^=true;
    }
}
// Push down lazy tags on every splay ancestor of p, from the splay root down to
// p itself, so that p and its parents are up to date before splay() rotates them.
// Iterative rather than recursive: a splay tree can degenerate into a long
// chain, and one recursion level per node risks a stack overflow on large inputs.
// The walk up (all son() checks) still happens before any pushdown, exactly as
// in the recursive form.
void To_pushdown(node* p){
    static std::vector<node*> chain;   // reused across calls to avoid reallocation
    chain.clear();
    node* q=p;
    chain.push_back(q);
    while(q->son()!=-1){               // -1: q is the root of its splay tree
        q=q->fa;
        chain.push_back(q);
    }
    for(auto it=chain.rbegin();it!=chain.rend();++it) (*it)->pushdown();
}
// Rotate p's child t=p->ch[f^1] up above p within their splay tree
// (f=0: t was p's right child, left rotation; f=1: t was p's left child,
// right rotation). Recomputes siz/sum of p and t. If p was a splay root, t takes
// over p's path-parent pointer. Tags are not pushed here; splay() does that
// beforehand via To_pushdown().
void Rotate(node* p,bool f){
node* t=p->ch[f^1];
// t's inner subtree (on side f) moves across to become p's child on side f^1
p->ch[f^1]=t->ch[f];
if(t->ch[f]!=null) t->ch[f]->fa=p;
t->ch[f]=p;
// p first: t's new aggregate includes p's
p->maintain();t->maintain();
// p->fa is still p's old parent here, so son() still reports p's old side
if(~p->son()) p->fa->ch[p->son()]=t;
t->fa=p->fa;p->fa=t;
}
// Bring p to the root of its splay tree. Pending tags on the chain from the
// splay root down to p are pushed first so rotations see correct children.
void splay(node* p){
To_pushdown(p);
// ~x is 0 only for x==-1, i.e. loop until p is a splay root
while(~p->son()){
// dir: side of its parent that p hangs on
int dir=p->son();
// zig-zig: parent hangs on the same side of the grandparent, so rotate the
// grandparent first; otherwise only p's parent is rotated this iteration
if(dir==p->fa->son()) Rotate(p->fa->fa,dir^1);
Rotate(p->fa,dir^1);
}
}
// Make the path from the tree root down to p a single preferred path: walk up
// through path-parent pointers, splaying each node and replacing its right
// (deeper) part with the path built so far.
void Access(node* p){
    for(node* below=null; p!=null; below=p, p=p->fa){
        splay(p);
        p->ch[1]=below;
        p->maintain();
    }
}
// Re-root p's tree at p: expose the root..p path, splay p to the top of it,
// then reverse the path so p becomes its shallowest vertex.
void Move_to_root(node *p){
    Access(p);
    splay(p);
    p->rev();
}
// Add edge x-y: make x the root of its tree, then hang that whole tree under y
// through a path-parent pointer. Not checked here: x and y must lie in
// different trees.
void Link(node* x,node* y){
    Move_to_root(x);
    x->fa=y;
}
// Remove edge x-y. After re-rooting at x and exposing the x..y path with y on
// top of its splay tree, x is expected to be y's entire left subtree. That only
// holds when x and y are adjacent, which is not checked here. Detaching that
// subtree splits the tree in two.
void Cut(node* x,node* y){
    Move_to_root(x);
    Access(y);
    splay(y);
    x->fa=null;
    y->ch[0]=null;
    y->maintain();
}
#undef int
int main(){
#define int long long
int n,m;
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n;i++) root[i]=new node();
for(int i=1;i<n;i++){
int x,y;
scanf("%lld%lld",&x,&y);
Link(root[x],root[y]);
}
while(m--){
char opt[2];
scanf("%s",opt);
if(opt[0]=='+'){
int x,y,z;
scanf("%lld%lld%lld",&x,&y,&z);
Move_to_root(root[x]);
Access(root[y]);splay(root[y]);
root[y]->add(z);
}
else if(opt[0]=='-'){
int x,y;
scanf("%lld%lld",&x,&y);
Cut(root[x],root[y]);
scanf("%lld%lld",&x,&y);
Link(root[x],root[y]);
}
else if(opt[0]=='*'){
int x,y,z;
scanf("%lld%lld%lld",&x,&y,&z);
Move_to_root(root[x]);
Access(root[y]);splay(root[y]);
root[y]->times(z);
}
else{
int x,y;
scanf("%lld%lld",&x,&y);
Move_to_root(root[x]);
Access(root[y]);splay(root[y]);
printf("%lld\n",root[y]->sum%mod);
}
}
return 0;
}