参考程序:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
// 图的邻接表
vector<int> g[N];
// 每个点的颜色,0表示白色,1表示黑色
int col[N];
// 树的节点个数
int n;
// 每个节点的深度
int dep[N];
// far[u][0] 表示在以 u 为根的子树中,距离 u 最远的白色点的深度
// far[u][1] 表示在以 u 为根的子树中,距离 u 最远的黑色点的深度
int far[N][2];
// 最终答案
int ans = 0;
// 深度优先搜索
void dfs(int x, int fa){
dep[x] = dep[fa] + 1; // 当前节点的深度比父节点多1
far[x][col[x]] = dep[x]; // 以自己为根,颜色相同的最远点就是自己
for(int i : g[x]){
if(i != fa){
dfs(i, x); // 遍历子树
// 枚举白-黑配对和黑-白配对情况
for(int j = 0; j < 2; j++){
// 如果当前节点 x 在颜色 j 有最远点,子节点 i 在颜色 j^1 有最远点
if(far[x][j] != -1 && far[i][j^1] != -1){
// 更新答案(两个路径深度减去2倍dep[x]得到真实路径长度)
ans = max(ans, far[x][j] - dep[x] + far[i][j^1] - dep[x]);
}
}
// 将子节点的信息合并回当前节点
for(int j = 0; j < 2; j++){
far[x][j] = max(far[x][j], far[i][j]);
}
}
}
// 当前节点自己向上传递,看它能不能与祖先形成不同色对
if(far[x][col[x]^1] != -1){
ans = max(ans, far[x][col[x]^1] - dep[x]);
}
}
int main(){
cin >> n;
memset(far, -1, sizeof far); // 初始化为-1表示没有颜色 j 的点
for(int i = 1; i <= n; i++){
cin >> col[i]; // 读入每个节点的颜色
}
// 构建无向树
for(int i = 1; i < n; i++){
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0); // 从根节点 1 开始 DFS
cout << ans << "\n"; // 输出最终答案
}