https://vjudge.net/problem/SPOJ-COT2
题目大意:给一棵
n
n
n个节点的数,每个节点有一种颜色,
m
m
m个询问
(
u
,
v
)
(u,v)
(u,v),你需要输出
u
、
v
u、v
u、v这条链上的不同的颜色数目。
思路:如果这道题是在区间内询问的,相信大家都能看出来是莫队。在树上怎么搞呢?树上莫队。通过欧拉序把链转换成区间,然后就可以做了。树上莫队的思想这里不想重复了,不懂的可以看下面这篇博文学习一下:
https://blog.csdn.net/xiji333/article/details/105179574
代码中也有注释,一起食用更佳。
#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
int n,m,res,tot,num;
int a[maxn],b[maxn],pos[maxn],cnt[maxn];//莫队需要用到的
int oula[maxn],in[maxn],out[maxn]; //欧拉序需要用到的
int head[maxn],deep[maxn],f[maxn][17],bs[17];//LCA需要用到的
int ans[maxn];//记录结果
bool vis[maxn];//判断一个点是否在路径上
struct node
{
int l,r,idx,lca=0;
bool operator <(const node &a)const
{
if(pos[l]==pos[a.l])
return r<a.r;
return pos[l]<pos[a.l];
}
}q[maxn];
struct Edge
{
int to,nxt;
}edge[maxn];
inline void addedge(int u,int v)
{
edge[++tot].to=v,edge[tot].nxt=head[u],head[u]=tot;
}
void dfs(int u,int fa)//LCA的预处理
{
oula[++num]=u,in[u]=num;
deep[u]=deep[fa]+1;
f[u][0]=fa;
for(int i=1;i<17;i++)
f[u][i]=f[f[u][i-1]][i-1];
int v;
for(int i=head[u];i;i=edge[i].nxt)
{
v=edge[i].to;
if(v!=fa)
dfs(v,u);
}
oula[++num]=u,out[u]=num;
}
inline int skip(int x,int level)//求LCA要用到
{
for(int i=0;i<17;i++)
if(level&bs[i])
x=f[x][i];
return x;
}
inline int LCA(int u,int v)//求LCA
{
if(deep[u]<deep[v])
swap(u,v);
u=skip(u,deep[u]-deep[v]);
if(u==v)
return u;
for(int i=16;i>=0;i--)
if(f[u][i]!=f[v][i])
u=f[u][i],v=f[v][i];
return f[u][0];
}
inline void work(int idx)//传进来的是节点编号
{
vis[idx]^=1;
if(vis[idx])//第一次出现的点 计算贡献
{
if(++cnt[a[idx]]==1)
++res;
}
else//第二次出现的点 说明这个点不在路径上
{
if(--cnt[a[idx]]==0)
--res;
}
}
int main()
{
for(int i=0;i<17;i++)
bs[i]=1<<i;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]),b[i]=a[i];
sort(b+1,b+1+n);
int nn=unique(b+1,b+1+n)-b-1;//因为颜色的范围比较大 所以需要离散化
for(int i=1;i<=n;i++)
a[i]=lower_bound(b+1,b+1+nn,a[i])-b;
int u,v;
for(int i=1;i<n;i++) //连边
{
scanf("%d%d",&u,&v);
addedge(u,v),addedge(v,u);
}
dfs(1,0); //dfs做LCA的预处理 同时把欧拉序列求出来
int dis=sqrt(num); //区间为[1,num]
for(int i=1;i<=num;i++)
pos[i]=(i-1)/dis+1;//莫队排序要用到
int lca;
for(int i=0;i<m;i++)
{
scanf("%d%d",&q[i].l,&q[i].r);
lca=LCA(q[i].l,q[i].r);
if(in[q[i].l]>in[q[i].r])//保证in[l]>=in[r] 这样计算得到的区间才合理哦
swap(q[i].l,q[i].r);
if(q[i].l==lca) //简单的情况
q[i].l=in[q[i].l],q[i].r=in[q[i].r];
else //复杂的情况
q[i].l=out[q[i].l],q[i].r=in[q[i].r],q[i].lca=lca;
q[i].idx=i;
}
sort(q,q+m);//询问离线
int l=1,r=0;
for(int i=0;i<m;i++)
{
while(l<q[i].l)
work(oula[l++]);
while(l>q[i].l)
work(oula[--l]);
while(r<q[i].r)
work(oula[++r]);
while(r>q[i].r)
work(oula[r--]);
if(q[i].lca) //特判掉LCA
work(q[i].lca);
ans[q[i].idx]=res;
if(q[i].lca) //要把LCA的贡献消除
work(q[i].lca);
}
for(int i=0;i<m;i++)
printf("%d\n",ans[i]);
return 0;
}