CF1029E-Tree with Small Distances
牛客链接
CF链接很好找,就不贴了
题意:
- 给你一颗树,要求你加入最少的边,使得1节点到任意节点的距离不超过2
- 数据范围,点数 20万
- 注意,这道题可能需要效率比较高的输入和树的遍历方式
做法:
- 我也好久没写过题了,虽然之前写过但不过都没有记录,现在还是记录一下,为了找工作做准备。
- 我拿到这道题,我一脸懵逼,WC怎么这么难,想了想不可能是贪心,因为我觉得贪心越复杂肯定是错误的,但不过贪心能不能做,我不知道,好像能做吧。
- 于是我就想到了动态规划,于是按照树形dp的套路,肯定是以 d p [ u ] dp[u] dp[u]表示以u节点为根的最优解。
- 因此就开始往这边想,但不过1维大部分时候是不够的,其实这个就很想那些经典的树形dp,比如最小覆盖等等。
- 首先这个是距离不超过2,这个就很有特点,因此一个点不是自己连向1就是通过自己的儿子或者父亲连向1
- 因此dp的状态如下
- d p [ u ] [ 1 ] dp[u][1] dp[u][1]自己连向1,并且下面的子节点都被覆盖。
- d p [ u ] [ 2 ] dp[u][2] dp[u][2]通过自己的父亲连向1,并且下面的子节点都被覆盖。
- d p [ u ] [ 3 ] dp[u][3] dp[u][3]通过自己的儿子连向1,并且下面的子节点都被覆盖。
- 既然这样定义了,那么就可以通过简单的逻辑来递推。
- 首先 d p [ u ] [ 1 ] dp[u][1] dp[u][1]肯定是从子节点递推上来的,递推式如下: d p [ u ] [ 1 ] = 1 + ∑ v m i n ( d p [ v ] [ 1 ] , d p [ v ] [ 2 ] , d p [ v ] [ 3 ] ) dp[u][1]= 1+\sum_{v}min(dp[v][1],dp[v][2],dp[v][3]) dp[u][1]=1+v∑min(dp[v][1],dp[v][2],dp[v][3])
- 对于上面的式子来说,如果u是1的儿子,可以不用加1。总的意思就是,我u节点是搞定了,直接选取每一个子树的最小的值就可以了。
- d p [ u ] [ 2 ] dp[u][2] dp[u][2]递推式如下: d p [ u ] [ 2 ] = ∑ v m i n ( d p [ v ] [ 1 ] , d p [ v ] [ 3 ] ) dp[u][2]=\sum_{v}min(dp[v][1],dp[v][3]) dp[u][2]=v∑min(dp[v][1],dp[v][3])
- 这个不能用从儿子节点的 d p [ v ] [ 2 ] dp[v][2] dp[v][2]状态递推过来,原因显而易见的。
- d p [ u ] [ 3 ] dp[u][3] dp[u][3]递推式如下: d p [ u ] [ 3 ] = ∑ v m i n ( d p [ v ] [ 1 ] , d p [ v ] [ 3 ] ) + ( s t a t u s = = 1 ) dp[u][3]=\sum_{v}min(dp[v][1],dp[v][3])+(status==1) dp[u][3]=v∑min(dp[v][1],dp[v][3])+(status==1)
- 上面也不能用从儿子节点的 d p [ v ] [ 2 ] dp[v][2] dp[v][2]状态递推过来,原因显而易见的。至于后面为什么加1,是因为这个状态是通过儿子连接的1节点,既然你没有选择儿子自己链接1的状态那么你肯定要加上1,这样可以证明是最优的,无非就是1 的差距。
- 最后统计答案的时候因为1节点可能有些特殊,最好还是通过1节点的儿子统计答案
代码:
#include <bits/stdc++.h>
using namespace std;
#define SZ(x) ((int)((x).size()))
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define pii pair<int,int>
#define pll pair<long long,long long>
#define rep(i, a, b) for(int i=(a);i<=(b);++i)
#define per(i, a, b) for(int i=(a);i>=(b);--i)
#define pb push_back
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 4e5 + 10;
int dp[maxn][4];
struct edge {
int v, nxt;
} ed[maxn << 1];
int head[maxn], cnt = 0;
void add(int u, int v) {
ed[++cnt] = edge{v, head[u]};
head[u] = cnt;
}
void dfs(int u, int f, int dep) {
if (dep >= 2) dp[u][1] = 1;
int flag = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == f) continue;
dfs(v, u, dep + 1);
dp[u][1] += min(dp[v][1], min(dp[v][2], dp[v][3]));
dp[u][2] += min(dp[v][1], dp[v][3]);
if (dp[v][1] <= dp[v][3] && dp[v][1] != 0) flag = 1;
}
dp[u][3] = dp[u][2];
if (!flag) dp[u][3] += 1;
}
int main() {
// ios::sync_with_stdio(false);
// cin.tie(nullptr);
// cout.tie(nullptr);
int n;
scanf("%d", &n);
rep(i, 1, n - 1) {
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
dfs(1, 0, 0);
int ans = 0;
for (int i = head[1]; i; i = ed[i].nxt) {
int v = ed[i].v;
ans += min(dp[v][1], min(dp[v][2], dp[v][3]));
}
cout << ans << endl;
return 0;
}