题意:给一棵带点权的树,多次询问两点间路径上的不同权值数
学习了一下莫队上树(雾
先求出栈入栈序$p_{1\cdots 2n}$,记$st_x$为$x$在$p$中第一次出现的位置,$ed_x$为$x$在$p$中最后一次出现的位置
对于一个询问$(x,y)$,先令$st_x\lt st_y$,求出其$lca$,若$x=lca$,则询问$[st_x,st_y]$,否则询问$[ed_x,st_y]$还有$lca$(因为$[ed_x,st_y]$不包含$lca$)
于是我们成功地把问题转化到序列上,普通地莫队即可
不应处理出现两次的数(因为它入栈一次,出栈一次,在路径之外)
#include<stdio.h>
#include<algorithm>
#include<map>
#include<math.h>
using namespace std;
struct ask{
int l,r,lca,sid,qid;
}q[100010];
int to[80010],nex[80010],h[40010],dep[40010],fa[40010][16],v[40010],p[80010],st[40010],ed[40010],ti[40010],ex[40010],ans[100010],M;
map<int,int>mp;
map<int,int>::iterator it;
bool cmp(ask a,ask b){
if(a.sid==b.sid)return a.r<b.r;
return a.sid<b.sid;
}
void add(int a,int b){
M++;
to[M]=b;
nex[M]=h[a];
h[a]=M;
}
void dfs(int x){
st[x]=++M;
p[M]=x;
for(int i=h[x];i;i=nex[i]){
if(to[i]!=fa[x][0]){
fa[to[i]][0]=x;
dep[to[i]]=dep[x]+1;
dfs(to[i]);
}
}
ed[x]=++M;
p[M]=x;
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
int i;
for(i=15;i>=0;i--){
if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
}
if(x==y)return x;
for(i=15;i>=0;i--){
if(fa[x][i]!=fa[y][i]){
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
int sum;
void add(int);
void del(int x){
if(ex[x]==0)return add(x);
ti[v[x]]--;
if(ti[v[x]]==0)sum--;
ex[x]=0;
}
void add(int x){
if(ex[x])return del(x);
if(ti[v[x]]==0)sum++;
ti[v[x]]++;
ex[x]=1;
}
int main(){
int n,m,i,j,x,y,l,r,sq;
scanf("%d%d",&n,&m);
sq=sqrt(n);
for(i=1;i<=n;i++){
scanf("%d",v+i);
mp[v[i]]=1;
}
for(i=1,it=mp.begin();it!=mp.end();it++,i++)it->second=i;
for(i=1;i<=n;i++)v[i]=mp[v[i]];
for(i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
M=0;
dep[1]=1;
dfs(1);
for(j=1;j<16;j++){
for(i=1;i<=n;i++)fa[i][j]=fa[fa[i][j-1]][j-1];
}
for(i=1;i<=m;i++){
scanf("%d%d",&x,&y);
if(st[x]>st[y])swap(x,y);
j=lca(x,y);
q[i].r=st[y];
if(x==j)
q[i].l=st[x];
else{
q[i].l=ed[x];
q[i].lca=j;
}
q[i].sid=q[i].l/sq;
q[i].qid=i;
}
sort(q+1,q+m+1,cmp);
l=r=1;
add(1);
for(i=1;i<=m;i++){
while(l>q[i].l){
l--;
add(p[l]);
}
while(r<q[i].r){
r++;
add(p[r]);
}
while(l<q[i].l){
del(p[l]);
l++;
}
while(r>q[i].r){
del(p[r]);
r--;
}
if(q[i].lca)add(q[i].lca);
ans[q[i].qid]=sum;
if(q[i].lca)del(q[i].lca);
}
for(i=1;i<=m;i++)printf("%d\n",ans[i]);
}