发现网上都是o(nlogn)的写法,没有o(n)的,所以这里贴一下代码,简单讲解一下(感觉好像很难讲明白,我实在是太菜了QAQ)
做法1:刚看题时的思路,不是很好理解,建议看做法2
显然如果派出一支军队访问子树,并回到子树根节点,那消耗的时间就是边长的2倍。首先,显然军队从根出发,经过一些路径后停留在叶子节点明显最优。考虑多引一支军队到达叶子节点,则该军队会先走过一段被其他军队走过的路径,会额外消耗一些时间。到达子树根节点后,子树根节点到某个叶子节点这段只用走一次,可以节省这段路径长度的时间。
dp[x]表示只考虑x节点所在的子树以及x指向父亲的那条边,至少多派一支军队走到x所在子树的叶子节点(且已经有其他军队进过x子树的父亲节点,且之前没有任何一个军队会走到x子树的叶子节点),最少会多消耗多少时间(负数表示能节省时间)。
转移的时候,如果有子树额外引一支军队能节省时间(即dp值小于0),那显然引一个军队走到该子树会使答案更优,如果没有任何一支军队dp值小于0,则选额外消耗时间最少的那一颗子树多派一支军队
#include<bits/stdc++.h>
using namespace std;
int dp[1001000];
int n;
vector<int>edg[1001000];
int Ti = 0;
void dfs(int now, int d) {
if(edg[now].size() == 0) {
dp[now] = (d-1) - 1;
return;
}
int sum = 0;
int num = 0;
int minn = 10000000;
for (int i = 0; i < edg[now].size(); i++) {
int nex = edg[now][i];
dfs(nex, d+1);
if (dp[nex] <= 0) {
num++;
sum+=dp[nex];
}
minn = min(minn, dp[nex]);
}
if (num > 0) {
dp[now] = sum - 2;
} else {
dp[now] = minn - 2;
}
}
void solve(){
Ti++;
scanf("%d",&n);
for (int i = 1; i <= n; i++) {
edg[i].clear();
dp[i] = 0;
}
for (int i = 2; i <= n; i++) {
int u;
scanf("%d",&u);
edg[u].push_back(i);
}
dfs(1, 0);
int ans = 0;
for (int i = 0;i < edg[1].size(); ++i) {
if (dp[edg[1][i]] < 0) {
ans += dp[edg[1][i]];
}
}
ans+= 2*(n-1);
printf("Case #%d: %d\n",Ti,ans);
}
int main(){
int T;
scanf("%d",&T);
while(T--) {
solve();
}
}
做法2:
换另一种DP思路,DP[0/1][n] 中 DP[0][i] 表示只考虑 i 节点所在的子树以及 i节点 指向父亲的那条边(对于6号节点来说考虑的范围是黄色子树,对于4号节点来说考虑的范围是红色子树),不派军队(DP[1][i]为派至少一只军队)的最小花费。
显然DP[0][i]等于子树 i 的节点数。
如果i节点为叶子节点,则DP[1][i] = dep[i],(即派一只军队走到该叶子结点的代价)dep[i]为节点i的深度(根节点为0)
如果i节点不为叶子节点,则DP[1][i] = min{DP[0][son1],DP[1][son1]} + min{DP[0][son2],DP[1][son2]} …… 其中son表示该节点的儿子节点。而且,假如所有的节点都是不派军队更优,则选一个 DP[1][son1] - DP[0][son1] 最小的,加上这个代价。
最后输出DP[1][1]即是答案
#include<bits/stdc++.h>
using namespace std;
int n, dep[1000010], num[1000010], dp[2][1000010];
vector<int> g[1000010];
//求子树size
void pre_dfs(int u) {
num[u] = 1;
for (int i = 0; i < g[u].size(); ++i) {
int v = g[u][i];
dep[v] = dep[u] + 1;
pre_dfs(v);
num[u] += num[v];
}
}
void dfs(int u) {
//如果是叶子节点
dp[0][u] = num[u] * 2;
if (!g[u].size()) {
dp[1][u] = dep[u];
return ;
}
//mn = dp[1][v] - dp[0][v] 的最小值
int mn = 1e9;
dp[1][u] = 0;
for (int i = 0; i < g[u].size(); ++i) {
int v = g[u][i];
dfs(v);
dp[1][u] += min(dp[0][v], dp[1][v]);
mn = min(mn, dp[1][v] - dp[0][v]);
}
if (mn > 0) {
dp[1][u] += mn;
}
}
int Ti = 0;
void solve(){
Ti++;
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
g[i].clear();
}
for (int i = 2; i <= n; ++i) {
int f;
scanf("%d", &f);
g[f].push_back(i);
}
dep[1] = 0;
pre_dfs(1);
dfs(1);
printf("Case #%d: %d\n", Ti, dp[1][1]);
}
int main() {
int T;
cin >> T;
while(T--) {
solve();
}
return 0;
}