一个简单的dp ,统计每一个结点包括自身和所有下属结点的个数sum,并统计每一个结点感染后可以保存最多多少个结点ss表示,如果他的分支为0就是0为1就是他的儿子结点的sum-1,如果为2则是max(a[pace[0]].sum - 1 + a[pace[1]].ss, a[pace[1]].sum - 1 + a[pace[0]].ss);然后就是一些小细节的实现了
代码:
#include<bits/stdc++.h>
#define AC return 0;
#define int long long
using namespace std;
typedef pair<int, int> pii;
const int maxx = 300005;
const int mod = 1e9 + 7;
int n, m, t;
struct node {
int sum;
vector<int>next;
bool vis;
int deep;
int ss;
node() {
deep = 0;
vis = 0;
sum = 0;
ss = 0;
next.clear();
}
} a[maxx];
int dfs(int k) {
a[k].vis = 1;
vector<int>pace;
int len = a[k].next.size();
for (int i = 0; i < len; i++)if(!a[a[k].next[i]].vis)pace.push_back(a[k].next[i]);
if (a[k].next.size() == 1 && k != 1) {
a[k].sum = 1;
a[k].deep = 1;
return a[k].sum;
} else {
a[k].sum = 1;
int p = a[k].next.size();
for (int i = 0; i < p ; i++) {
if (!a[a[k].next[i]].vis)a[k].sum += dfs(a[k].next[i]);
}
}
if (pace.size() == 0) {
a[k].ss = 0;
} else if (pace.size() == 1) {
a[k].ss = a[pace[0]].sum - 1;
} else {
a[k].ss = max(a[pace[0]].sum - 1 + a[pace[1]].ss, a[pace[1]].sum - 1 + a[pace[0]].ss);
}
return a[k].sum;
}
int ans = 0;
signed main() {
cin >> t;
while (t--) {
cin >> n;
ans = 0;
for (int i = 0; i <= n; i++) {
a[i].next.clear();
a[i].sum = 0;
a[i].vis = 0;
a[i].deep = 0;
a[i].ss = 0;
}
for (int i = 1; i < n; i++) {
int x, y;
cin >> x >> y;
a[x].next.push_back(y);
a[y].next.push_back(x);
}
dfs(1);
cout << a[1].ss << '\n';
}
AC
}