题目大意
给出一有 n n n 个点的树,现在要拆除一些线,但是需要保证至少有 k k k 个点,满足每个点都可以和至少一个点联通。
解题思路
做法一
考虑树形 DP
。
显然,对于两点一线的情况的个数,应该越多越好。
假设如果是 x x x 对点(两点一线),且 x ∗ 2 ≥ k x*2≥k x∗2≥k,那么只需要 ( k + 1 ) / 2 (k+1)/2 (k+1)/2 条边。
否则,说明有多出的点,多出的点要连出一条边与两点一线联通,则需要 x + ( k − x ∗ 2 ) x + (k-x*2) x+(k−x∗2) 条边。
现在问题就转为求这样的点对有多少。
设 d p i , 1 dp_{i,1} dpi,1 表示以 i i i 为根的子树中能够组成许多两点一线的最大点数,包含节点 i i i。
设 d p i , 0 dp_{i,0} dpi,0 为 ∑ v ∈ s o n s [ u ] d p v , 1 \sum\limits_{v \in sons[u] }^{} dp_{v,1} v∈sons[u]∑dpv,1。
当结点 u u u 与它的子结点 v v v 连边时,其余子结点都无法与结点 u u u 连边,并且结点 v v v 的子结点无法和结点 v v v 连边,所以 d p u , 1 = m a x ( d p u , 1 , d p u , 0 − d p v , 1 + d p v , 0 + 1 ) dp_{u,1}=max(dp_{u,1},dp_{u,0}-dp_{v,1}+dp_{v,0}+1) dpu,1=max(dpu,1,dpu,0−dpv,1+dpv,0+1)。
转移方程:
d p u , 0 = ∑ v ∈ s o n s [ u ] d p v , 1 dp_{u,0}=\sum\limits_{v \in sons[u] }^{} dp_{v,1} dpu,0=v∈sons[u]∑dpv,1
d p u , 1 = m a x ( d p u , 1 , d p u , 0 − d p v , 1 + d p v , 0 + 1 ) dp_{u,1}=max(dp_{u,1},dp_{u,0}-dp_{v,1}+dp_{v,0}+1) dpu,1=max(dpu,1,dpu,0−dpv,1+dpv,0+1)
最后 d p 1 , 1 dp_{1,1} dp1,1 就是上面所说的 x x x 了。
做法二
当然,有 DP
就有贪心,先思考这张图。
对于这张图,思考,我们要从上往下匹配,还是从下往上匹配。
显然是从下往上匹配,因为一个子节点往上只有一个父节点,但一个节点往下可能有多个子节点。
显然最优的情况一定是一条边匹配两个点,即两点一线。
然后从下往上匹配,每遇到两个没访问过的点,就相连,这就是贪心的思路。
但是有可能树的形状不能满足这样的匹配,于是剩下的点就只能用一条边匹配了。
AC CODE
树形 DP
#include <bits/stdc++.h>
using namespace std;
#define _ 100005
int T, n, k, ans;
int tot, head[_], to[_ << 1], nxt[_ << 1];
int f[_][2];
void add(int u, int v)
{
to[++tot] = v;
nxt[tot] = head[u];
head[u] = tot;
}
void dfs(int u, int fa)
{
for(int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if(v == fa) continue;
dfs(v, u);
f[u][0] += f[v][1];
}
f[u][1] = f[u][0];
for(int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if(v == fa) continue;
f[u][1] = max(f[u][1], f[u][0] - f[v][1] + f[v][0] + 1);
}
}
signed main()
{
scanf("%d", &T);
while(T--)
{
tot = 0;
memset(head, 0, sizeof head);
memset(f, 0, sizeof f);
scanf("%d%d", &n, &k);
for(int i = 1; i < n; ++i)
{
int a;
scanf("%d", &a);
add(i + 1, a);
add(a, i + 1);
}
dfs(1, 0);
int kkk = (k & 1);
k = k - kkk;
if(f[1][1] * 2 >= k) printf("%d\n", k / 2 + kkk);
else printf("%d\n", f[1][1] + (k - f[1][1] * 2) + kkk);
}
return 0;
}
贪心
#include<bits/stdc++.h>
using namespace std;
const int _ = 100005;
int tot, ans;
int head[_], nxt[_ << 1], to[_ << 1];
bool vis[_];
inline int read()
{
int X = 0, w = 1;
char ch = 0;
while(ch < '0' || ch > '9')
{
if(ch == '-') w = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') X = (X << 3) + (X << 1) + ch - '0', ch = getchar();
return X * w;
}
inline void add(int x, int y)
{
nxt[++tot] = head[x];
head[x] = tot;
to[tot] = y;
}
inline void dfs(int x, int y)
{
for(int i = head[x]; i; i = nxt[i])
if(to[i] != y)
{
dfs(to[i], x);
if(!vis[x] && !vis[to[i]]) vis[x] = vis[to[i]] = 1, ans++;
}
}
signed main()
{
int T = read();
while(T--)
{
int n = read(), k = read();
tot = 0;
ans = 0;
memset(head, 0, sizeof(head));
memset(vis, 0,sizeof(vis));
for(int i = 1; i < n; i++)
{
int a = read();
add(a, i + 1);
add(i + 1, a);
}
dfs(1, 0);
if(ans * 2 >= k)
printf("%d\n",(k + 1) / 2);
else
printf("%d\n", ans + (k - ans * 2));
}
return 0;
}