题目链接
题目解法
对于子树上的操作,首先想到的就是dfn序
通过dfn序可以把一个子树压成一段连续的区间
我们考虑换根操作对一个点的子树的dfn序的影响
对于上面的树,dfn序为:1,2,4,5,6,7,3
把初始根 1 换成 5 后,2 的子树有 [ 2,4,5,6,7 ] 变成了 [ 1,2,3,4 ]
发现如果替换的根 root 在 x 的子树中时,
那么在根为 root 时 x 的子树就是刨去 root 所在的 x 的子节点的子树 的序列
如果替换的根 root 不在 x 的子树中时,
那么 x 的子树 和 根为 1 时 x 的子树一样
这样会将子树变成 1-2 段序列
我们将 dfn序 倍长一下,就可以将子树变成 1 段连续的序列
其中求 root 所在的 x 的子树可以用对每个点开一个 vector 记录儿子的 dfn 序,然后二分一下即可
也可以用无脑的倍增思想(但我不想)
问题转化成了如何求 [ l1 , r1 ],[ l2 , r2] 中相同的数的个数
这就和SNOI2017 一个简单的询问非常相似了
那题的做法可以参考题解
于是将每个询问拆成 4 个询问就可以了
时间复杂度
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int,int> pii;
const int N(200100),M(500100),inf(0x3f3f3f3f);
struct Query{
int id,l,r,neg;
}query[M<<2];
int n,m,mq,a[N],disc[N],clo;
int B,pos[N],cnt1[N],cnt2[N];
int top,dfn[N],rv[N],siz[N];
vector<int> vec[N];
vector<pii> start[N];
LL res,ans[M];
inline int read(){
int FF=0,RR=1;
char ch=getchar();
for(;!isdigit(ch);ch=getchar())
if(ch=='-')
RR=-1;
for(;isdigit(ch);ch=getchar())
FF=(FF<<1)+(FF<<3)+ch-48;
return FF*RR;
}
void dfs(int u,int fa){
siz[u]=1,dfn[u]=++top,rv[top]=u;
for(int i=0;i<vec[u].size();i++){
int v=vec[u][i];
if(v!=fa){
dfs(v,u);
siz[u]+=siz[v];
start[u].push_back(make_pair(dfn[v],v));
}
}
}
void calc(int &l,int &r,int x,int root){
if(x==root)
l=1,r=n;
else if(dfn[x]<=dfn[root]&&dfn[root]<=dfn[x]+siz[x]-1){
int pos=upper_bound(start[x].begin(),start[x].end(),make_pair(dfn[root],inf))-start[x].begin()-1;
int v=start[x][pos].second;
l=dfn[v]+siz[v],r=dfn[v]+n-1;
}
else
l=dfn[x],r=dfn[x]+siz[x]-1;
}
void add_edge(int l1,int r1,int l2,int r2,int x){
query[++clo]={x,r1,r2,1},query[++clo]={x,l1-1,l2-1,1};
query[++clo]={x,r1,l2-1,-1},query[++clo]={x,l1-1,r2,-1};
}
bool cmp(const Query &x,const Query &y){
if(pos[x.l]^pos[y.l])
return pos[x.l]<pos[y.l];
return pos[x.l]&1?x.r<y.r:x.r>y.r;
}
int main(){
n=read(),m=read();
for(int i=1;i<=n;i++)
a[i]=disc[i]=read();
sort(disc+1,disc+n+1);
int tot=unique(disc+1,disc+n+1)-disc-1;
for(int i=1;i<=n;i++)
a[i]=a[i+n]=lower_bound(disc+1,disc+tot+1,a[i])-disc;
for(int i=1,x,y;i<n;i++){
x=read(),y=read();
vec[x].push_back(y);
vec[y].push_back(x);
}
dfs(1,-1);
for(int i=1;i<=n;i++)
rv[i+n]=rv[i];
int root=1;
for(int i=1,op;i<=m;i++){
op=read();
if(op==1)
root=read();
else{
int x=read(),y=read();
int x_l,x_r,y_l,y_r;
calc(x_l,x_r,x,root),calc(y_l,y_r,y,root);
add_edge(x_l,x_r,y_l,y_r,++mq);
}
}
B=sqrt(n<<1);
for(int i=1;i<=n<<1;i++)
pos[i]=(i-1)/B+1;
sort(query+1,query+clo+1,cmp);
for(int k=1,i=0,j=0;k<=clo;k++){
int l=query[k].l,r=query[k].r,id=query[k].id,neg=query[k].neg;
if(!l||!r)
continue;
while(i<l) i++,res+=cnt2[a[rv[i]]],cnt1[a[rv[i]]]++;
while(i>l) res-=cnt2[a[rv[i]]],cnt1[a[rv[i]]]--,i--;
while(j<r) j++,res+=cnt1[a[rv[j]]],cnt2[a[rv[j]]]++;
while(j>r) res-=cnt1[a[rv[j]]],cnt2[a[rv[j]]]--,j--;
ans[id]+=neg*res;
}
for(int i=1;i<=mq;i++)
printf("%lld\n",ans[i]);
return 0;
}