题目
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e4+10, M = N << 1;
const int INF = 0x3f3f3f3f;
int f[N][2];
int c[N];
int m, n;
int h[N], e[M], ne[M], idx;
void add(int a, int b) // 添加一条边a->b
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void dfs(int p, int u)
{
f[u][0] += 1;
f[u][1] += 1;
for(int k = h[u]; ~k; k = ne[k])
{
int j = e[k];
if(j == p) continue;
dfs(u, j);
f[u][0] += min(f[j][0] - 1, f[j][1]);
f[u][1] += min(f[j][0], f[j][1] - 1);
}
}
int main()
{
cin >> m >> n;
for(int i = 1; i <= n; i++) cin >> c[i], f[i][!c[i]] = INF;
memset(h, -1, sizeof h);
for(int i = 1; i < m; i++)
{
int a, b;
cin >> a >> b;
add(a, b);
add(b, a);
}
dfs(0, n+1);
cout << min(f[n+1][0], f[n+1][1]);
return 0;
}