挺厉害的一道大数据结构题.
由于 LCT 是维护树的形态的,所以说不支持翻转操作.
而在维护序列时 splay 是支持区间翻转的.
所以,我们对于 LCT 中每一个重链都维护一个 splay(这个不同于 LCT 中的 splay)
由于重链是一个序列,所以是支持序列的区间翻转的.
那么我们的翻转,链加和,链求和的操作就都在这个重链对应的 splay 上进行.
然后这里一定要注意:我们在 LCT 中维护 LCT 中每个点对应到 splay 上的编号,只有 LCT 中的 splay 的根节点对应的是正确的编号.
#include <cstdio>
#include <string>
#include <vector>
#include <cstring>
#include <algorithm>
#define N 50007
#define ll long long
using namespace std;
namespace IO
{
void setIO(string s)
{
string in=s+".in";
string out=s+".out";
freopen(in.c_str(),"r",stdin);
freopen(out.c_str(),"w",stdout);
}
};
namespace Splay
{
#define lson s[x].ch[0]
#define rson s[x].ch[1]
struct node
{
int ch[2],f,rev,size;
ll add,val,sum,Min,Max;
}s[N];
int sta[N];
int get(int x) { return s[s[x].f].ch[1]==x; }
void mark_rev(int x) { s[x].rev^=1,swap(lson,rson);}
void mark_add(int x,ll v)
{
s[x].add+=v;
s[x].sum+=1ll*s[x].size*v;
s[x].Min+=v,s[x].Max+=v,s[x].val+=v;
}
void pushup(int x)
{
s[x].sum=s[x].Min=s[x].Max=s[x].val;
s[x].size=s[lson].size+s[rson].size+1;
if(lson)
{
s[x].sum+=s[lson].sum;
s[x].Min=min(s[x].Min,s[lson].Min);
s[x].Max=max(s[x].Max,s[lson].Max);
}
if(rson)
{
s[x].sum+=s[rson].sum;
s[x].Min=min(s[x].Min,s[rson].Min);
s[x].Max=max(s[x].Max,s[rson].Max);
}
}
void pushdown(int x)
{
if(s[x].rev)
{
if(lson) mark_rev(lson);
if(rson) mark_rev(rson);
s[x].rev=0;
}
if(s[x].add)
{
if(lson) mark_add(lson,s[x].add);
if(rson) mark_add(rson,s[x].add);
s[x].add=0;
}
}
void rotate(int x)
{
int old=s[x].f,fold=s[old].f,which=get(x);
s[old].ch[which]=s[x].ch[which^1];
if(s[old].ch[which]) s[s[old].ch[which]].f=old;
s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold;
if(fold) s[fold].ch[s[fold].ch[1]==old]=x;
pushup(old),pushup(x);
}
void splay(int x)
{
int fa,v=0,tmp=x;
for(;tmp;tmp=s[tmp].f) sta[++v]=tmp;
for(;v;--v) pushdown(sta[v]);
for(;fa=s[x].f;rotate(x))
if(s[fa].f)
rotate(get(fa)==get(x)?fa:x);
}
int get_kth(int x,int kth)
{
pushdown(x);
if(kth<=s[lson].size) return get_kth(lson,kth);
else if(s[lson].size+1==kth) return x;
else return get_kth(rson,kth-s[lson].size-1);
}
int findrt(int x)
{
while(s[x].f) { x=s[x].f; }
return x;
}
#undef lson
#undef rson
};
#define ls s[x].ch[0]
#define rs s[x].ch[1]
struct node
{
int ch[2],f,rev,size;
}s[N];
int sta[N],rt[N];
int get(int x) { return s[s[x].f].ch[1]==x; }
int Isr(int x) { return s[s[x].f].ch[0]!=x&&s[s[x].f].ch[1]!=x; }
void mark(int x) { swap(ls,rs), s[x].rev^=1; }
void pushup(int x) { s[x].size=s[ls].size+s[rs].size+1; }
void pushdown(int x)
{
if(s[x].rev)
{
s[x].rev=0;
if(ls) mark(ls);
if(rs) mark(rs);
}
}
void rotate(int x)
{
int old=s[x].f,fold=s[old].f,which=get(x);
if(!Isr(old)) s[fold].ch[s[fold].ch[1]==old]=x;
s[old].ch[which]=s[x].ch[which^1];
if(s[old].ch[which]) s[s[old].ch[which]].f=old;
s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold;
pushup(old),pushup(x);
}
void splay(int x)
{
int u=x,v=0,fa;
for(sta[++v]=u;!Isr(u);u=s[u].f) sta[++v]=s[u].f;
rt[x]=rt[u];
for(;v;--v) pushdown(sta[v]);
for(u=s[u].f;(fa=s[x].f)!=u;rotate(x))
if(s[fa].f!=u)
rotate(get(fa)==get(x)?fa:x);
}
void Access(int x)
{
for(int y=0;x;y=x,x=s[x].f)
{
splay(x);
if(rs) // cut
{
rt[x]=Splay::get_kth(rt[x],s[ls].size+1);
Splay::splay(rt[x]);
rt[rs]=Splay::s[rt[x]].ch[1];
Splay::s[rt[rs]].f=0;
Splay::s[rt[x]].ch[1]=0;
Splay::pushup(rt[x]);
}
if(y) // link
{
rt[x]=Splay::get_kth(rt[x],s[ls].size+1);
Splay::splay(rt[x]);
Splay::s[rt[x]].ch[1]=rt[y];
Splay::s[rt[y]].f=rt[x];
Splay::pushup(rt[x]);
}
rs=y;
pushup(x);
}
}
void makeroot(int x)
{
Access(x),splay(x),mark(x),Splay::mark_rev(rt[x]);
}
void split(int x,int y)
{
makeroot(x),Access(y),splay(y);
}
#undef ls
#undef rs
int edges;
int hd[N],to[N<<1],nex[N<<1];
void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
void dfs(int u,int ff)
{
s[u].f=ff;
rt[u]=u;
s[u].size=1;
Splay::s[rt[u]].size=1;
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff) continue;
dfs(v,u);
}
}
int main()
{
// IO::setIO("input");
int i,j,n,m,R;
scanf("%d%d%d",&n,&m,&R);
for(i=1;i<n;++i)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs(1,0);
for(i=1;i<=m;++i)
{
char op[10];
int x,y,z;
scanf("%s",op+1);
if(op[3]=='c')
{
scanf("%d%d%d",&x,&y,&z);
split(x,y);
Splay::mark_add(rt[y],(ll)z);
}
if(op[3]=='m')
{
scanf("%d%d",&x,&y);
split(x,y);
printf("%lld\n",Splay::s[rt[y]].sum);
}
if(op[3]=='j')
{
scanf("%d%d",&x,&y);
split(x,y);
printf("%lld\n",Splay::s[rt[y]].Max);
}
if(op[3]=='n')
{
scanf("%d%d",&x,&y);
split(x,y);
printf("%lld\n",Splay::s[rt[y]].Min);
}
if(op[3]=='v')
{
scanf("%d%d",&x,&y);
split(x,y);
Splay::mark_rev(rt[y]);
}
}
return 0;
}