大概就是BZOJ5016[Snoi2017]一个简单的询问
+BZOJ3083遥远的国度
吧
换根就像遥远的国度
一样讨论
然后像一个简单的询问
一样差分
跑树上莫队即可
据说卡常严重要用fread
,fwrite
还有玄学调参才行
然而我只用普通的读入优化也没有调参就过了
#include<bits/stdc++.h>
using namespace std;
#define gc c=getchar()
#define r(x) read(x)
#define ll long long
template<typename T>
inline void read(T&x){
x=0;T k=1;char gc;
while(!isdigit(c)){if(c=='-')k=-1;gc;}
while(isdigit(c)){x=x*10+c-'0';gc;}x*=k;
}
const int N=100005;
const int M=500005;
vector<int> G[N];
int ac[N][20],dep[N],in[N],ptn[N],out[N],dfs_clock;
void dfs(int x,int f){
ac[x][0]=f;
dep[x]=dep[f]+1;
in[x]=++dfs_clock;
ptn[dfs_clock]=x;
for(int i=1;(ac[x][i]=ac[ac[x][i-1]][i-1]);++i);
for(int i=0;i<G[x].size();++i){
int v=G[x][i];
if(v!=f)dfs(v,x);
}
out[x]=dfs_clock;
}
inline int find(int x,int t){
for(int i=16;~i;--i){
if(dep[x]-dep[t]>(1<<i))x=ac[x][i];
}
// assert(ac[x][0]==t);
return x;
}
int be[N],a[N],b[N],cntl[N],cntr[N];
struct Query{
int l,r,type,id;
bool operator < (Query x) const{
return be[l]==be[x.l]?r<x.r:l<x.l;
}
};
ll Ans[M];
vector<Query> Q;
int n,m,root=1;
inline void get_query(int x,vector<pair<int,int> > &ret){
if(x==root)ret.push_back(make_pair(1,n));
else if(in[x]<in[root]&&out[root]<=out[x]){
int t=find(root,x);
ret.push_back(make_pair(1,in[t]-1));
ret.push_back(make_pair(out[t]+1,n));
}
else ret.push_back(make_pair(in[x],out[x]));
}
inline void query(int x,int y,int id){
vector<pair<int,int> >T1,T2;
get_query(x,T1);
get_query(y,T2);
for(int i=0;i<T1.size();++i){
for(int j=0;j<T2.size();++j){
Q.push_back(Query{T1[i].second,T2[j].second,1,id});
if(T1[i].first>1)Q.push_back(Query{T1[i].first-1,T2[j].second,-1,id});
if(T2[j].first>1)Q.push_back(Query{T1[i].second,T2[j].first-1,-1,id});
if((T1[i].first>1)&&(T2[j].first>1))Q.push_back(Query{T1[i].first-1,T2[j].first-1,1,id});
}
}
}
int main(){
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
r(n),r(m);
int block_size=sqrt(n);
for(int i=1;i<=n;++i){
r(a[i]);
b[i]=a[i];
be[i]=(i-1)/block_size+1;
}
sort(b+1,b+n+1);
int tot=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;++i)a[i]=lower_bound(b+1,b+tot+1,a[i])-b;
for(int i=1,u,v;i<n;++i){
r(u),r(v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,0);
int id=0;
for(int i=1,opt,x,y;i<=m;++i){
r(opt);
if(opt==1)r(root);
else {
r(x),r(y);
query(x,y,++id);
}
}
sort(Q.begin(),Q.end());
ll ans=0;
int l=0,r=0;
for(int i=0;i<Q.size();++i){
Query &x=Q[i];
while(l<x.l)cntl[a[ptn[++l]]]++,ans+=cntr[a[ptn[l]]];
while(l>x.l)cntl[a[ptn[l]]]--,ans-=cntr[a[ptn[l--]]];
while(r<x.r)cntr[a[ptn[++r]]]++,ans+=cntl[a[ptn[r]]];
while(r>x.r)cntr[a[ptn[r]]]--,ans-=cntl[a[ptn[r--]]];
Ans[x.id]+=x.type*ans;
}
for(int i=1;i<=id;++i)printf("%lld\n",Ans[i]);
}