题目链接:
对于每个点,它的答案最大就是与它距离最远的点的距离。
而如果与它距离为$x$的点有大于等于两个,那么与它距离小于等于$x$的点都不会被计入答案。
所以我们需要找到对于每个点$u$距离它最远的点及最小的距离$x$满足距离$u$的距离大于等于$x$的点都只有一个。
那么怎么找距离每个点最远的点?
这个点自然就是树的直径的一个端点了!
我们将树的直径先找到,然后讨论一下对于每个点,有哪些点可能会被计入答案:
如图所示,我们以点$x$为例,假设它距离直径两端点中的$S$较近($y$为$x$距离直径上最近的点),设$dis$代表两点距离:
对于$y$左边点所有点,显然$S$与$y$的距离最远,但$dis(S,y)<dis(T,y)$,所以$y$左边的所有点都不会被计入答案。
对于在$x$子树中的点,他们与$x$的距离要小于$dis(y,T)$,也就小于$dis(x,T)$,所以不会被计入答案。
对于在$y$子树中但不在$x$子树中的点(例如$b$),因为$dis(y,b)\le dis(y,S)$,所以$dis(b,d)<dis(S,d)$,不会被计入答案。
对于$y$与$T$之间的点的子树中的点(例如$c$),显然$dis(y,c)\le dis(y,T)$,所以这类点不会被计入答案。
那么综上所述对于靠近$S$的点,只有$x$到$T$之间的点才有可能被计入答案,对于靠近$T$的点同理。
所以我们只需要分别以$S$和$T$为根遍历整棵树,用一个单调栈保存每个点到根的这条链上能被计入答案的点即可。
求不同权值个数,再开一个桶记录栈中每种权值的个数,每次进栈或弹栈时对应加减。
因为答案与深度有关,我们将原树长链剖分。
对于每个点,当走重儿子时,求出所有轻儿子的子树中的最长链长度$len$,将当前栈中与$x$距离小于等于$len$的点弹出;当遍历轻儿子时,将当前栈中与$x$距离小于等于$x$往下最长链长度的点弹出。
注意要在弹栈之后再把$x$压入栈中,而且遍历每个儿子前都要重新将$x$压入栈中。
最后将以$S$为根时的答案与以$T$为根时的答案取最大值即可。
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int head[200010];
int to[400010];
int dep[200010];
int next[400010];
int mx[200010];
int son[200010];
int st[200010];
int top;
int tot;
int S,T;
int res;
int ans[200010];
int n,m;
int x,y;
int col[200010];
int cnt[200010];
void add(int x,int y)
{
next[++tot]=head[x];
head[x]=tot;
to[tot]=y;
}
void pop()
{
cnt[col[st[top]]]--;
res-=(cnt[col[st[top]]]==0);
top--;
}
void push(int x)
{
st[++top]=x;
cnt[col[x]]++;
res+=(cnt[col[x]]==1);
}
void dfs(int x,int fa)
{
dep[x]=dep[fa]+1;
for(int i=head[x];i;i=next[i])
{
if(to[i]!=fa)
{
dfs(to[i],x);
}
}
}
void dfs1(int x,int fa)
{
son[x]=0;
mx[x]=0;
dep[x]=dep[fa]+1;
for(int i=head[x];i;i=next[i])
{
if(to[i]!=fa)
{
dfs1(to[i],x);
if(mx[to[i]]>mx[son[x]])
{
son[x]=to[i];
}
}
}
mx[x]=mx[son[x]]+1;
}
void dfs2(int x,int fa)
{
if(!son[x])
{
ans[x]=max(ans[x],res);
return ;
}
int len=0;
for(int i=head[x];i;i=next[i])
{
if(to[i]!=fa&&to[i]!=son[x])
{
len=max(len,mx[to[i]]);
}
}
while(top&&dep[st[top]]>=dep[x]-len)
{
pop();
}
push(x);
dfs2(son[x],x);
for(int i=head[x];i;i=next[i])
{
if(to[i]!=fa&&to[i]!=son[x])
{
while(top&&dep[st[top]]>=dep[x]-mx[son[x]])
{
pop();
}
push(x);
dfs2(to[i],x);
}
}
while(top&&dep[st[top]]>=dep[x]-mx[son[x]])
{
pop();
}
ans[x]=max(ans[x],res);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
for(int i=1;i<=n;i++)
{
scanf("%d",&col[i]);
}
dfs(1,0);
for(int i=1;i<=n;i++)
{
S=dep[i]>dep[S]?i:S;
}
dfs(S,0);
for(int i=1;i<=n;i++)
{
T=dep[i]>dep[T]?i:T;
}
dfs1(S,0);
dfs2(S,0);
dfs1(T,0);
dfs2(T,0);
for(int i=1;i<=n;i++)
{
printf("%d\n",ans[i]);
}
}