题意:给你一棵N个节点的树,每个节点有一个权值,对于M个询问(u,v),你需要回答u和v这两个节点间有多少种不同的点权。
思路:1点权比较大需要离散化。2对这棵树求其dfs序分块,3对m组询问的L和R求lca,4莫队,lca需要特殊处理。
#include <bits/stdc++.h>
using namespace std;
const int maxn=40050;
const int maxm=1e5+10;
int n,m,t;
struct node{
int id,l,r,extra;
}s[maxm];
int block[maxn*2];
bool cmp(const node &a,const node &b)
{
if(block[a.l]!=block[b.l]) return a.l<b.l;
return a.r<b.r;
}
int in[maxn],out[maxn],p[maxn*2],dep[maxn],fa[maxn][22];
int a[maxn],c[maxn];
vector<int>g[maxn];
int tim=0;
void dfs(int u,int pre)
{
in[u]=++tim;
p[tim]=u;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(g[u][i]==pre) continue;
dep[v]=dep[u]+1;
fa[v][0]=u;
dfs(v,u);
}
out[u]=++tim;
p[tim]=u;
}
int lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;i--)
if (dep[fa[x][i]]>=dep[y])
x=fa[x][i];
if (x==y) return x;
for(int i=20;i>=0;i--)
if (fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int vis[maxn];
int sum;
void expand(int x)
{
vis[p[x]]^=1;
if (vis[p[x]])
{
c[a[p[x]]]++;
if (c[a[p[x]]]==1) sum++;
}
else
{
c[a[p[x]]]--;
if (c[a[p[x]]]==0) sum--;
}
}
int ans[maxm];
void Modui()
{
int l=1,r=0;
sum=0;
for(int i=1;i<=m;i++)
{
while (s[i].l<l) expand(--l);
while (s[i].r>r) expand(++r);
while (s[i].l>l) expand(l++);
while (s[i].r<r) expand(r--);
if (s[i].extra&&!c[a[s[i].extra]])
ans[s[i].id]=sum+1;
else ans[s[i].id]=sum;
}
}
struct BB{
int id,x;
}b[maxn];
bool cb(const BB&a,const BB&b)
{
return a.x<b.x;
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
b[i].x=a[i];
b[i].id=i;
}
sort(b+1,b+n+1,cb);
int o=0;
for(int i=1;i<=n;i++)
{
if(i==1||b[i].x!=b[i-1].x) o++;
a[b[i].id]=o;
}
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
dep[1]=1;
tim=0;
dfs(1,0);
for(int j=1;j<=20;j++)
for(int i=1;i<=n;i++)
fa[i][j]=fa[fa[i][j-1]][j-1];
t=(int)sqrt(2.0*n);
for(int i=1;i<=(n<<1);i++)
block[i]=i/t;
for(int i=1;i<=m;i++)
{
int x,y,gg;
scanf("%d%d",&x,&y);
s[i].id=i;
gg=lca(x,y);
if(gg==x||gg==y)
{
s[i].l=in[x];
s[i].r=in[y];
s[i].extra=0;
}
else
{
s[i].extra=gg;
if(out[x]<=in[y])
{
s[i].l=out[x];
s[i].r=in[y];
}
else
{
s[i].l=out[y];
s[i].r=in[x];
}
}
if(s[i].l>s[i].r)
swap(s[i].l,s[i].r);
}
sort(s+1,s+m+1,cmp);
Modui();
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
}
return 0;
}