思路:
只需要找到一棵包含所有颜色的子树(或者子树之外的部分包含所有颜色),然后求另一侧离它最远的点就行了
怎么判断一棵子树是否包含所有颜色呢?
对每种颜色,把它的点按 dfn 排序后求 lca,然后用树上差分统计每棵子树中出现的颜色种数。
code
#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
// Capacity for 1-based vertex / colour indices.
const int MAXN = 1000000 + 10;

int n;    // number of vertices (read in main)
int m;    // number of colours (read in main)
int cnt;  // running DFS pre-order timestamp

// f[v][i]: 2^i-th ancestor of v (0 above the root), filled by dfs().
int f[MAXN][21];
// max_dep[v][0] / max_dep[v][1]: largest / second-largest value of
// max_dep[child][0] + 1 over v's children (0 when there are fewer children).
int max_dep[MAXN][2];
// max_up[v]: longest path (in edges) from v whose first step is to v's parent.
int max_up[MAXN];
// bz[v]: colour-difference marks; subtree sums after dfs_t().
int bz[MAXN];
// bz2[v]: nonzero if v's subtree contains the LCA of some colour (after dfs_t()).
int bz2[MAXN];
// dep[v]: depth of v, root has depth 1.
int dep[MAXN];
// dfn[v]: DFS pre-order timestamp of v.
int dfn[MAXN];

// col[c]: vertices carrying colour c.
vector<int> col[MAXN];
// b[v]: adjacency list of the tree.
vector<int> b[MAXN];
// Strict weak ordering of vertices by DFS pre-order timestamp:
// the vertex with the smaller dfn comes first. Used to sort colour lists.
bool cmp(int x, int y) {
    return dfn[y] > dfn[x];
}
void dfs(int x, int fa) {
dfn[x] = ++ cnt;
for(int i = 1; i <= 20; i ++)
f[x][i] = f[f[x][i - 1]][i - 1];
dep[x] = dep[fa] + 1;
for(int i = 0; i < b[x].size(); i ++) {
int y = b[x][i];
if(y == fa) continue;
f[y][0] = x;
dfs(y, x);
if(max_dep[y][0] + 1 > max_dep[x][0]) max_dep[x][1] = max_dep[x][0], max_dep[x][0] = max_dep[y][0] + 1;
else if(max_dep[y][0] + 1 > max_dep[x][1]) max_dep[x][1] = max_dep[y][0] + 1;
}
}
// Lowest common ancestor of x and y by binary lifting.
// Requires dfs() to have filled dep[] and f[v][i] (the 2^i-th ancestor of v,
// 0 above the root). The 21 levels (i = 0..20) cover any depth difference
// below 2^21.
int lca(int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);  // make y the deeper vertex
    // Lift y by exactly dep[y] - dep[x] levels, testing the bits of the
    // difference directly (no running 1 << j, so never a negative shift count).
    const int diff = dep[y] - dep[x];
    for (int i = 20; i >= 0; i--)
        if ((diff >> i) & 1) y = f[y][i];
    if (x == y) return x;
    // Climb both while their 2^i-th ancestors differ; they stop just below the LCA.
    for (int i = 20; i >= 0; i--)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}
// Second traversal of the tree hanging from x (fa is x's parent, 0 for the
// root). It must run after dfs() has filled max_dep[][].
// Pre-order : for each child y of u, max_up[y] is raised to the longest path
//             (in edges) that starts at y and first steps to u. That path
//             either continues upward (max_up[u] + 1) or descends through
//             u's best child other than y. Such a descent is
//             max_dep[u][1] + 1 when y gives u's top height, otherwise
//             max_dep[u][0] + 1.
// Post-order: bz[] is summed and bz2[] OR-ed into the parent, turning the
//             per-vertex marks into subtree aggregates.
// Uses an explicit stack instead of recursion, so a deep (path-like) tree with
// ~1e6 vertices cannot overflow the call stack.
void dfs_t(int x, int fa) {
    vector<int> stk_node;     // vertex of each open frame
    vector<int> stk_par;      // parent of that vertex
    vector<size_t> stk_next;  // next index into b[vertex] to examine

    stk_node.push_back(x);
    stk_par.push_back(fa);
    stk_next.push_back(0);

    while (!stk_node.empty()) {
        const size_t top = stk_node.size() - 1;
        const int u = stk_node[top];
        if (stk_next[top] < b[u].size()) {
            const int y = b[u][stk_next[top]++];
            if (y == stk_par[top]) continue;  // do not walk back to the parent
            max_up[y] = max(max_up[y], max_up[u] + 1);
            if (max_dep[u][0] == max_dep[y][0] + 1)
                max_up[y] = max(max_up[y], max_dep[u][1] + 1);
            else
                max_up[y] = max(max_up[y], max_dep[u][0] + 1);
            stk_node.push_back(y);
            stk_par.push_back(u);
            stk_next.push_back(0);
            continue;
        }
        // Every neighbour of u is done: close u's frame...
        stk_node.pop_back();
        stk_par.pop_back();
        stk_next.pop_back();
        if (stk_node.empty()) break;  // u was the start vertex x
        // ...and push u's subtree aggregates into its parent.
        const int p = stk_node.back();
        bz[p] += bz[u];
        bz2[p] |= bz2[u];
    }
}
// Reads the coloured tree, builds the colour marks, runs both traversals and
// prints the answer.
// Input: n m, then the colour of each vertex 1..n, then n-1 undirected edges.
// For each colour c, with its vertices sorted by dfn:
//   bz[v]  is +1 for every vertex of colour c and -1 at the LCA of each pair of
//          dfn-adjacent vertices. After dfs_t() it is the number of distinct
//          colours in v's subtree.
//   bz2[]  is set at the LCA of all vertices of colour c. After dfs_t() it is
//          nonzero iff some colour lies entirely inside v's subtree.
// Every vertex v whose subtree holds all m colours offers max_up[v]. Every
// vertex whose subtree holds no complete colour offers max_dep[v][0] + 1.
// The program prints the best offer + 1.
int main() {
    scanf("%d%d", &n, &m);
    for (int v = 1; v <= n; v++) {
        int c;
        scanf("%d", &c);
        col[c].push_back(v);
    }
    for (int e = 1; e < n; e++) {
        int u, w;
        scanf("%d%d", &u, &w);
        b[u].push_back(w);
        b[w].push_back(u);
    }

    dfs(1, 0);

    for (int c = 1; c <= m; c++) {
        vector<int> &vs = col[c];
        sort(vs.begin(), vs.end(), cmp);
        int top = 0;  // LCA of the vertices processed so far (0: none yet)
        for (size_t j = 0; j < vs.size(); j++) {
            const int v = vs[j];
            bz[v]++;
            if (top == 0) {
                top = v;
                continue;
            }
            top = lca(top, v);
            // Cancel the double count of this colour above the LCA of
            // dfn-adjacent vertices.
            bz[lca(vs[j - 1], v)]--;
        }
        bz2[top] = true;
    }

    dfs_t(1, 0);

    int best = 0;
    for (int v = 1; v <= n; v++) {
        if (bz[v] == m)
            best = max(best, max_up[v]);
        if (!bz2[v])
            best = max(best, max_dep[v][0] + 1);
    }
    printf("%d", best + 1);
    return 0;
}