漂亮的花园
题解:思路:
结论:
对于一颗树。假设树上的两种节点颜色为x,y.那么节点x与节点y的最远距离为。x颜色类的节点中两个距离最远的点中的某一点.到y颜色类的节点中距离最远的两个点中的某一点.
证明:
设x颜色节点最远的两个点为x1,x2.
y颜色节点最远的两个点为y1,y2.
如果y1到x3距离很大。显然y1到x1,x2中的某一点的距离更大.
因为当x3在x1-x2的路径中时,显然到两端更大。
当不在路径中时,显然可以通过x3到路径上,且更大.
很抽象,画图更容易理解.
所以枚举四个点的距离即可.这里需要注意的是不一定所有颜色都有两个以上节点。
所有需要特判
因为c[i]过大,所以先离散化颜色.然后将各种颜色的点放入。
然后预处理出每种颜色的节点最远距离的两个点。
树上两点距离可以通过Lca计算.
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int mod=1e9+9;
const int maxn=1e5+10;
int a[maxn],b[maxn];
int n,q;
map<int,int> mp;
vector<int> col[maxn];
vector<int> G[maxn];
int f[maxn][21];
int p[21];
int dep[maxn];
void ko()
{
p[0]=1;
for(int i=1;i<=19;i++){
p[i]=p[i-1]*2;
}
}
void dfs(int u,int fa)
{
f[u][0]=fa;
dep[u]=dep[fa]+1;
for(int i=1;i<=19;i++){
f[u][i]=f[f[u][i-1]][i-1];
}
for(int i=0;i<G[u].size();i++){
int to=G[u][i];
if(to==fa)continue;
dfs(to,u);
}
}
void init()
{
ko();
dfs(1,0);
}
int lca(int x,int y)
{
if(dep[x]<dep[y])swap(x,y);
int d=dep[x]-dep[y];
for(int i=19;i>=0;i--){
if(d>=p[i]){
d-=p[i];
x=f[x][i];
}
}
if(x==y)return x;
for(int i=19;i>=0;i--){
if(dep[x]>=p[i]&&f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
int dis(int x,int y)
{
return dep[x]+dep[y]-2*dep[lca(x,y)];
}
int main()
{
mp.clear();
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);mp[a[i]]++;
b[i]=a[i];
}
sort(a+1,a+1+n);
int len=unique(a+1,a+1+n)-(a+1);
for(int i=1;i<=n;i++){
int x=lower_bound(a+1,a+1+len,b[i])-a;
col[x].push_back(i);
}
int u,v;
for(int i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
G[u].push_back(v);G[v].push_back(u);
}
init();
for(int i=1;i<=len;i++){
for(int j=0;j<col[i].size();j++){
int d1=dis(col[i][0],col[i][j]);
int d2=dis(col[i][1],col[i][j]);
int d3=dis(col[i][0],col[i][1]);
if(d1>d3&&d2>d3){
if(d1>d2){
col[i][1]=col[i][j];
}
else{
col[i][0]=col[i][j];
}
}
else
if(d1>d3){
col[i][1]=col[i][j];
}
else
if(d2>d3){
col[i][0]=col[i][j];
}
}
}
while(q--){
int x,y;
scanf("%d%d",&x,&y);
if((!mp[x])||(!mp[y])){
printf("0\n");
}
else{
x=lower_bound(a+1,a+1+len,x)-a;
y=lower_bound(a+1,a+1+len,y)-a;
if(col[x].size()==1&&col[y].size()==1){
printf("%d\n",dis(col[x][0],col[y][0]));
}
else
if(col[x].size()==1&&col[y].size()>1){
printf("%d\n",max(dis(col[x][0],col[y][0]),dis(col[x][0],col[y][1])));
}
else
if(col[x].size()>1&&col[y].size()==1){
printf("%d\n",max(dis(col[y][0],col[x][0]),dis(col[y][0],col[x][1])));
}
else{
int x1=col[x][0],x2=col[x][1],y1=col[y][0],y2=col[y][1];
int d1=dis(x1,y1),d2=dis(x1,y2),d3=dis(x2,y1),d4=dis(x2,y2);
d1=max(d1,d2);
d1=max(d1,d3);
d1=max(d1,d4);
printf("%d\n",d1);
}
}
}
return 0;
}