Title
SP10707 COT2 - Count on a tree II
Solution
注意lca别打错了
注意欧拉序的区间别错了
注意不带修的莫队块的大小设置为sqrt(n),否则为pow(n,2.0/3.0)
先把节点离散化一下。
我们可以求树上的欧拉序,对于任意两个点上的路径,就是从
i
i
i到
j
j
j的只出现了一个数的节点。假如是它们的
l
c
a
lca
lca的话,就不会出现在序列中,所以要特殊处理。
用
f
i
r
s
t
first
first标记第一次出现的位置,
用
l
a
s
t
last
last标记第二次出现的位置。
如果两个点在同一条链上,那么我们的莫队的区间就是
f
i
r
s
t
[
l
]
first[l]
first[l]到
f
i
r
s
t
[
r
]
first[r]
first[r],否则是
l
a
s
t
[
l
]
last[l]
last[l]到
f
i
r
s
t
[
r
]
first[r]
first[r],还有标记
l
c
a
lca
lca,因为
f
i
r
s
t
[
l
]
first[l]
first[l]到
l
a
s
t
[
l
]
last[l]
last[l]这段区间对答案没有贡献。
莫队就是正常的莫队了。
具体实现中,可以用
v
i
s
vis
vis标记每个点是否被访问,异或处理一下。
Code
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#define rep(i,x,y) for(int i=x;i<=y;i++)
using namespace std;
const int N=2e5+5;
int read(){
int p=0; char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) p=(p<<3)+(p<<1)+c-48,c=getchar();
return p;
}
struct node1{
int l,r,lca,id;
}q[N];
struct node{
int y,next;
}a[N];
int n,Q,m,size,cnt[N],f[N][22],t=20,c[N],b[N],head[N],tot,ans[N],val,bg[N],dep[N],first[N],last[N],num,d[N];
bool vis[N];
void add(int x,int y){
a[++tot]=(node){y,head[x]}; head[x]=tot;
}
bool cmp(node1 x,node1 y){
return (bg[x.l]^bg[y.l])?bg[x.l]<bg[y.l]:((bg[x.l]&1)?x.r<y.r:x.r>y.r);
}
void dfs(int x){
d[++num]=x; first[x]=num;
for(int i=head[x];i;i=a[i].next){
int y=a[i].y;
if (y!=f[x][0]){
dep[y]=dep[x]+1;
f[y][0]=x;
for(int i=1;i<=t;i++) f[y][i]=f[f[y][i-1]][i-1];
dfs(y);
}
}
d[++num]=x; last[x]=num;
}
int lca(int x,int y){
if (dep[x]<dep[y]) swap(x,y);
for(int i=t;i>=0;i--) if (dep[f[x][i]]>=dep[y]/*&&f[x][i]!=0*/) x=f[x][i]; //"1" -> "0"
if (x==y) return x;
for(int i=t;i>=0;i--) if (f[y][i]!=f[x][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void work(int x){
vis[x]?val-=!--cnt[c[x]]:val+=!cnt[c[x]]++;
vis[x]^=1;
}
int main(){
n=read(); Q=read();
rep(i,1,n) b[i]=c[i]=read();
sort(b+1,b+n+1);
m=unique(b+1,b+n+1)-b-1;
rep(i,1,n) c[i]=lower_bound(b+1,b+m+1,c[i])-b; //"-b-1" -> "-b" ",a[i]"->"c[i]"
rep(i,1,n-1) {
int x=read(),y=read();
add(x,y),add(y,x);
}
dep[1]=1;
dfs(1);
size=sqrt(num);
rep(i,1,num) bg[i]=(i-1)/size+1;
rep(i,1,Q){
int l=read(),r=read(),tlca=lca(l,r);
if (first[l]>first[r]) swap(l,r);
if (tlca==l){
q[i].l=first[l];
q[i].r=first[r];
} else {
q[i].l=last[l];//"last[r]"->"last[l]"
q[i].r=first[r];
q[i].lca=tlca;
}
q[i].id=i;
}
sort(q+1,q+Q+1,cmp);
int l=1,r=0;
rep(i,1,Q){
int ql=q[i].l,qr=q[i].r,tlca=q[i].lca;
while (l>ql) work(d[--l]);
while (r<qr) work(d[++r]);
while (l<ql) work(d[l++]);
while (r>qr) work(d[r--]);
if (tlca) work(tlca);
ans[q[i].id]=val;
if (tlca) work(tlca);
}
rep(i,1,Q) printf("%d\n",ans[i]);
return 0;
}