题意看这篇博客:https://blog.csdn.net/dreaming__ldx/article/details/88418543
思路看这篇:https://blog.csdn.net/corsica6/article/details/88115948
有个坑点:不能用深搜(递归 DFS)去回溯具体方案,否则 test 14 会 MLE(也可能只是我的写法不够好),所以下面的代码改用 BFS 来找方案。
代码:
#include <bits/stdc++.h>
using namespace std;

// Type aliases instead of textual macros: same types, but scoped and checked by the compiler.
using LL = long long;        // 64-bit: DP values are sums of many a[i] and may overflow int
using pii = pair<int, int>;  // (vertex, dp state) pairs stored in the BFS queue

constexpr int maxn = 200010;  // capacity of the 1-based vertex-indexed arrays

// Adjacency list ("forward star"): head[x] is the newest edge id of x,
// Next[e] the following edge id (0 ends the list), ver[e] the target of edge e,
// tot the number of edge slots used. Each undirected edge takes two slots.
int head[maxn], Next[maxn * 2], ver[maxn * 2], tot;
int a[maxn];         // a[i]: cost the DP charges when vertex i is chosen
LL dp[maxn][2];      // the two DP states per vertex, filled by dfs1
LL sum[maxn];        // sum[x]: sum of dp[child][0] over the children of x
LL f[maxn];          // f[x]: parent of x in the rooted tree (-1 for the root)
bool v[maxn][2];     // v[x][s]: pair (x, s) already expanded by bfs
bool res[maxn];      // res[x]: x appears in at least one optimal solution
vector<int> ans;     // vertices with res[i] set
bool is_leaf[maxn];  // is_leaf[x]: x has no children
// Pushes the directed edge x -> y onto the front of x's adjacency list.
// Edge ids start at 1, so id 0 in head[] / Next[] marks the end of a list.
void add(int x, int y) {
    const int id = ++tot;
    ver[id] = y;
    Next[id] = head[x];
    head[x] = id;
}
void dfs1(int x, int fa) {
int cnt = 0;
f[x] = fa;
for (int i = head[x]; i; i = Next[i]) {
int y = ver[i];
if(y == fa) continue;
dfs1(y, x);
sum[x] += dp[y][0];
cnt++;
}
if(cnt == 0) {
dp[x][0] = a[x];
dp[x][1] = 0;
is_leaf[x] = 1;
return;
}
for (int i = head[x]; i; i = Next[i]) {
int y = ver[i];
if(y == fa) continue;
dp[x][0] = min(dp[x][0], sum[x] - dp[y][0] + dp[y][1] + a[x]);
dp[x][1] = min(dp[x][1], sum[x] - dp[y][0] + dp[y][1]);
}
dp[x][0] = min(dp[x][0], sum[x]);
}
// Marks in res[] every vertex that belongs to at least one optimal solution.
//
// Walks breadth-first over (vertex, state) pairs, where the state indexes the
// DP of dfs1 (0 -> dp[x][0], 1 -> dp[x][1]). It starts from (1, 0) because
// main reports dp[1][0]. For each dequeued pair it re-evaluates the
// transitions of dfs1 and enqueues every child state used by some transition
// that attains the optimum. It is iterative rather than a recursive search.
//
// Every push goes from x to a child of x (y != f[x]), so pairs are dequeued in
// non-decreasing depth order. NOTE(review): because of that ordering, the
// v[y][1] / v[pos][0] checks made *before* pushing look like they can never be
// true (a child is only marked when it is popped). Duplicates are actually
// filtered by the v[x][flag] check right after popping.
void bfs() {
    queue<pii> q;
    q.push(make_pair(1, 0));
    while(!q.empty()) {
        pii tmp = q.front();
        q.pop();
        int x = tmp.first, flag = tmp.second;
        // Expand each (vertex, state) pair at most once.
        if(v[x][flag]) continue;
        v[x][flag] = 1;
        if(flag == 0) {
            // State 0, value dp[x][0].
            // pos: last child whose "child in state 1, buy x" option is optimal;
            // cnt: how many children attain that option.
            int pos = -1, cnt = 0;
            if(is_leaf[x]) {
                // dfs1 sets dp[leaf][0] = a[leaf]: the leaf itself is bought.
                res[x] = 1;
                continue;
            }
            for (int i = head[x]; i; i = Next[i]) {
                int y = ver[i];
                if(y == f[x]) continue;
                // Option "y in state 1, other children in state 0, plus a[x]" is optimal.
                if(dp[x][flag] == sum[x] - dp[y][0] + dp[y][1] + a[x]) {
                    if(v[y][1]) continue;
                    res[x] = 1;  // buying x is part of an optimal choice
                    q.push(make_pair(y, 1));
                    pos = y;
                    cnt++;
                }
            }
            // Every child except pos takes state 0 in some optimal choice.
            for (int i = head[x]; i; i = Next[i]) {
                int y = ver[i];
                if(v[y][0]) continue;
                if(y == f[x] || y == pos) continue;
                q.push(make_pair(y, 0));
            }
            // pos can also take state 0 if another child ties for the state-1
            // role (cnt > 1), or if not buying x is equally optimal
            // (dp[x][0] == sum[x], i.e. all children in state 0).
            // The `continue` below only skips the push; it is the last statement of the loop body.
            if(cnt > 1 || (sum[x] == dp[x][0] && pos != -1)) {
                if(v[pos][0]) continue;
                q.push(make_pair(pos, 0));
            }
        } else {
            // State 1, value dp[x][1]: exactly one child in state 1, others in
            // state 0, x not bought. (For a leaf both loops find no child.)
            int pos = -1, cnt = 0;
            for (int i = head[x]; i; i = Next[i]) {
                int y = ver[i];
                if(y == f[x]) continue;
                if(dp[x][flag] == sum[x] - dp[y][0] + dp[y][1]) {
                    if(v[y][1]) continue;
                    q.push(make_pair(y, 1));
                    pos = y;
                    cnt++;
                }
            }
            // Every child except pos takes state 0 in some optimal choice.
            for (int i = head[x]; i; i = Next[i]) {
                int y = ver[i];
                if(v[y][0]) continue;
                if(y == f[x] || y == pos) continue;
                q.push(make_pair(y, 0));
            }
            // With a tie for the state-1 role, pos can take state 0 as well.
            if(cnt > 1) {
                if(v[pos][0]) continue;
                q.push(make_pair(pos, 0));
            }
        }
    }
}
// Reads n, the vertex costs a[1..n] and n-1 undirected edges, runs the tree
// DP rooted at vertex 1, and recovers the vertices used by some optimal solution.
// Prints the optimal cost dp[1][0], then the number of such vertices, then the
// vertices themselves in increasing order.
int main() {
    int n, x, y;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    // Every byte 0x3f -> every dp entry is 0x3f3f3f3f3f3f3f3f, used as "infinity".
    memset(dp, 0x3f, sizeof(dp));
    dfs1(1, -1);
    bfs();
    // Scanning i upward means ans is already sorted; no extra sort is needed.
    for (int i = 1; i <= n; i++) {
        if(res[i])
            ans.push_back(i);
    }
    // ans.size() is size_t; %d requires an int, so passing size_t is undefined behavior.
    printf("%lld %d\n", dp[1][0], static_cast<int>(ans.size()));
    for (int id : ans)
        printf("%d ", id);
    printf("\n");
    return 0;
}
最小生成树(MST)解法先留个坑在这……
补坑:
思路看这篇博客:https://www.cnblogs.com/river-flows-in-you/p/10596821.html
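说一下我个人的理解:按 DFS 序给叶子编号 1..k,再额外加一个点 k+1,这 k+1 个点作为新图的顶点。原树中每个顶点 x 控制的是一段连续的叶子区间 [l, r],用差分的思想,把"购买 x"看成新图中连接 l 和 r+1、权值为 a[x] 的一条边。只要选出的边能构成新图的生成树,就能让叶子取任意值:形成生成树之后,对每个点的赋值操作就可以类比为在生成树上遍历的过程。因此最小代价就是新图最小生成树的权值,而一个点出现在某个最优方案中,当且仅当它对应的边属于某棵最小生成树。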
代码:
#include <bits/stdc++.h>
using namespace std;

// Type alias / constants instead of textual macros: same values, scoped and type-checked.
using LL = long long;              // 64-bit accumulator for the total cost
constexpr int INF = 0x3f3f3f3f;    // "no leaf seen yet" sentinel for l[]
constexpr int maxn = 200010;       // capacity of the 1-based vertex-indexed arrays

// Adjacency list ("forward star"): head[x] is the newest edge id of x,
// Next[e] the following edge id (0 ends the list), ver[e] the target of edge e.
int head[maxn], Next[maxn * 2], ver[maxn * 2], tot;
int sz[maxn];          // sz[x]: number of vertices in x's subtree
int cnt;               // number of leaves numbered so far by dfs
int l[maxn], r[maxn];  // [l[x], r[x]]: leaf-index range of x's subtree
int a[maxn];           // a[i]: cost of vertex i
int f[maxn];           // disjoint-set parent array used by get()
bool v[maxn];          // v[i]: vertex i is part of some optimal answer

// Candidate edge for vertex `id`: joins difference-array positions u and v,
// weight w = cost of the vertex. Ordered by weight only, for Kruskal.
struct Edge {
    int u, v, w, id;
    bool operator<(const Edge& rhs) const { return w < rhs.w; }
};
Edge b[maxn];          // b[x]: candidate edge of vertex x (filled by dfs)
// Pushes the directed edge x -> y onto the front of x's adjacency list.
// Edge ids start at 1, so id 0 in head[] / Next[] marks the end of a list.
void add(int x, int y) {
    const int id = ++tot;
    ver[id] = y;
    Next[id] = head[x];
    head[x] = id;
}
// Post-order DFS from x (parent fa) over the tree rooted at 1.
// Vertices with no children (sz == 1) get consecutive leaf indices 1, 2, ...
// via ++cnt in DFS order, so the leaves of any subtree form the contiguous
// range [l[x], r[x]]. Vertex x is then recorded in b[x] as the candidate edge
// (l[x], r[x] + 1) with weight a[x]: the two endpoints of x's range in a
// difference array over the leaves.
void dfs(int x, int fa) {
    sz[x] = 1;
    l[x] = INF;  // identity element for min
    r[x] = -1;   // identity element for max
    for (int i = head[x]; i; i = Next[i]) {
        int y = ver[i];
        if (y == fa) continue;
        dfs(y, x);
        sz[x] += sz[y];
        l[x] = min(l[x], l[y]);
        r[x] = max(r[x], r[y]);
    }
    if (sz[x] == 1) {  // no children: x is a leaf
        l[x] = r[x] = ++cnt;
    }
    // Standard list-initialization; a (Edge){...} compound literal is only a GNU extension in C++.
    b[x] = Edge{l[x], r[x] + 1, a[x], x};
}
// Returns the disjoint-set representative of x and compresses the path, so
// every node visited ends up pointing directly at the root.
// Iterative on purpose: main links roots with a plain f[x] = y (no union by
// rank), so parent chains can grow long, and a recursive find would recurse
// to that depth.
int get(int x) {
    int root = x;
    while (f[root] != root) root = f[root];
    while (f[x] != root) {
        const int parent = f[x];
        f[x] = root;
        x = parent;
    }
    return root;
}
// Reads the tree and builds one candidate edge per vertex (see dfs).
// Runs Kruskal, processing edges of equal weight as a group:
//  - pass 1 marks every edge of the group whose endpoints are in different
//    components before the group is processed; such an edge belongs to some
//    minimum spanning tree;
//  - pass 2 performs the unions and adds the weight of each edge used.
// Prints the minimum total cost and the number of marked vertices, then the
// marked vertex ids in increasing order.
int main() {
    int n, x, y;
    scanf("%d", &n);
    int sum = 0;  // number of vertices marked in v[]
    LL ans = 0;   // total weight of the spanning tree built
    for (int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs(1, -1);
    sort(b + 1, b + 1 + n);
    // Leaves are numbered 1..cnt; edges may also end at position cnt + 1.
    cnt++;
    for (int i = 1; i <= cnt; i++) f[i] = i;
    for (int L = 1, R; L <= n; L = R + 1) {
        // [L, R] is the maximal run of edges with weight b[L].w.
        R = L;
        // Check the bound first so b[R + 1] is never read once R == n.
        while (R < n && b[R + 1].w == b[L].w) R++;
        for (int i = L; i <= R; i++) {
            x = get(b[i].u), y = get(b[i].v);
            if (x != y) {
                v[b[i].id] = 1;
                sum++;
            }
        }
        for (int i = L; i <= R; i++) {
            x = get(b[i].u), y = get(b[i].v);
            if (x != y) {
                ans += b[i].w;
                f[x] = y;
            }
        }
    }
    printf("%lld %d\n", ans, sum);
    for (int i = 1; i <= n; i++)
        if (v[i]) printf("%d ", i);
    printf("\n");
    return 0;
}