题目传送门:23333
我是用树剖写的,剖边不剖点,就是正常的树剖。
只不过把第一个节点不计入线段树里。
只是数据有点大开数组要节省(其实一般都不会炸,但不知为何我经常MLE)。
但听说还有别的做法来节省代码量,就像我神奇的同桌的隔走廊同桌所说的那样。
似乎叫做链上求和,详情百度。
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <iostream>
#define rep(j,k,l) for (int j=k;j<=l;j++)
#define N 250005
using namespace std;
struct _233{
int l,r,sum;
} tr[N*5];
int n,m,cnt,to[N*2],ne[N*2],st[N];
int size[N],fa[N],deep[N],son[N],top[N],id[N],dfn[N];
void add(int k,int l,int p){
to[p]=l;
ne[p]=st[k];
st[k]=p;
}
void dfs1(int rt,int dad){
size[rt]=1;
fa[rt]=dad;
deep[rt]=deep[dad]+1;
son[rt]=0;int _=0;
for (int i=st[rt];i!=0;i=ne[i])
if (to[i]!=dad){
dfs1(to[i],rt);
size[rt]+=size[to[i]];
if (size[to[i]]>_){
_=size[to[i]];
son[rt]=to[i];
}
}
}
void dfs2(int rt,bool qaz){
if (qaz) top[rt]=top[fa[rt]];
else top[rt]=rt;
dfn[++cnt]=rt;
id[rt]=cnt;
if (son[rt]!=0) dfs2(son[rt],1);
for (int i=st[rt];i!=0;i=ne[i])
if (to[i]!=fa[rt]&&to[i]!=son[rt])
dfs2(to[i],0);
}
void stree(int k,int l,int r){
if (l==r){
if (l!=1) tr[k].sum=1;
return;
}
tr[k].l=++cnt;
tr[k].r=++cnt;
stree(tr[k].l,l,(l+r)/2);
stree(tr[k].r,(l+r)/2+1,r);
tr[k].sum=tr[tr[k].l].sum+tr[tr[k].r].sum;
return;
}
int sss(int k,int l,int r,int o,int p){
if (r<o||l>p) return 0;
if (o<=l&&r<=p) return tr[k].sum;
return sss(tr[k].l,l,(l+r)/2,o,p)+sss(tr[k].r,(l+r)/2+1,r,o,p);
}
int solve(int x,int y){
int ans=0;
for (;top[x]!=top[y];x=fa[top[x]]){
if (deep[top[x]]<deep[top[y]]) swap(x,y);
ans+=sss(1,1,n,id[top[x]],id[x]);
}
if (x!=y) ans+=sss(1,1,n,min(id[x],id[y])+1,max(id[x],id[y]));
return ans;
}
void change(int k,int l,int r,int o){
if (o<l||o>r) return;
if (l==r){
tr[k].sum=0;
return;
}
change(tr[k].l,l,(l+r)/2,o);
change(tr[k].r,(l+r)/2+1,r,o);
tr[k].sum=tr[tr[k].l].sum+tr[tr[k].r].sum;
}
int main(){
scanf("%d",&n);
rep(i,1,n-1){
int k,l;
scanf("%d%d",&k,&l);
add(k,l,2*i-1);add(l,k,2*i);
}
dfs1(1,0);
dfs2(1,0);
cnt=1;
stree(1,1,n);
scanf("%d",&m);
rep(i,1,n+m-1){
char ch=getchar();
while (ch<'A'||ch>'Z') ch=getchar();
if (ch=='W'){
int k;
scanf("%d",&k);
printf("%d\n",solve(1,k));
}
else{
int k,l;
scanf("%d%d",&k,&l);
if (fa[k]!=l) swap(k,l);
change(1,1,n,id[k]);
}
}
return 0;
}