Codeforces 1223E
Tag : 动态规划,树上DP,图论,贪心
题目分析
题目大意
给定一个无向带权树,求一个满足下列条件的路径集的最大边权和。
- 一个路径最多出现一次
- 同一个点最多被连接 k k k 次
做法猜测
贪心
首先考虑仿照 MST 的排序算法,毕竟 MST 也是选边,然后求的最小值。
但是发现这个做法是不正确的。
比如下面这个输入
1
4 1
1 2 2
2 3 3
3 4 2
正确的选择方式应该是选择两个 2 2 2 的边,而简单的 MST 贪心会选择那个长度为 3 3 3 的边。
当然如果你们贪出来了,我想学习一下
动态规划
既然不能做贪心,那就考虑贪心的上一个等级,也就是动态规划。
动态规划设计
前置说明
记当前节点为
u
u
u,子节点为
v
i
v_{i}
vi,子节点集合为
s
o
n
u
son_{u}
sonu
使用的编译指令为
g++ filename.cpp -o filename.exe -O2 -std=c++11
预处理头文件与常量定义为
#include<bits/stdc++.h>
using namespace std;
const int N = 5e5+17
存图方式为链式前向星,其中存图的代码可以参考下述
int n,tails,fr[N],to[N<<1],nxt[N<<1],ct[N<<1];
void add(int f,int t,int w){
to[++tails] = t;
nxt[tails] = fr[f];
fr[f] = tails;
ct[tails] = w;
return;
}
遍历的方式为
void DFS(int u,int fa){
for(int zj=fr[u],v;zj;zj=nxt[zj]){
if((v=to[zj]) == fa) continue;
DFS(v,start);
}
return;
}
状态说明V1.0
设计下列状态
F [ i ] [ j ] , i ∈ { 1 , 2 , 3 , . . . , n } , j ∈ { 0 , 1 , 2 , 3 , . . . , k } F[i][j] , i \in \{1,2,3,...,n\}, j \in \{0,1,2,3,...,k\} F[i][j],i∈{1,2,3,...,n},j∈{0,1,2,3,...,k}
表示 节点 i i i 连接至多 j j j 条边时的最优结果。
so easy!
状态分析V1.0
状态转移方程的确立
初始值需要考虑 u u u 对 v v v 的影响。
F [ v ] [ 0 ] = 0 , F [ v ] [ i ] = c t [ z j ] , i ∈ { 1 , 2 , 3 , . . . , k } F[v][0] = 0 , F[v][i] = ct[zj],i \in \{1,2,3,...,k\} F[v][0]=0,F[v][i]=ct[zj],i∈{1,2,3,...,k}
状态转移也可以设计出来
F
[
u
]
[
i
]
=
m
a
x
{
F
[
u
]
[
i
−
1
]
+
F
[
v
i
]
[
k
−
1
]
,
F
[
u
]
[
i
]
+
F
[
v
i
]
[
k
]
}
F[u][i] = max\{F[u][i-1] + F[v_{i}][k-1] , F[u][i] + F[v_{i}][k]\}
F[u][i]=max{F[u][i−1]+F[vi][k−1],F[u][i]+F[vi][k]}
Notice:
i
i
i 从大到小进行枚举
时空复杂度
空间复杂度
看起来是个
n
∗
k
n*k
n∗k 的数组,但是实际上我们可以通过统计一下总的儿子个数,然后结合
v
e
c
t
o
r
vector
vector 进行一波数据优化
但是这种做法应该能被卡出
M
L
E
MLE
MLE
时间复杂度
这不是 O ( n ∗ k ) O(n*k) O(n∗k) 吗
结合统计儿子个数的优化,能勉强前进一些。
小细节的优化
每次清空 f r fr fr 数组的时候,不必全部清空,全部清空的话,复杂度就完全的 O ( N ∗ q ) O(N*q) O(N∗q) 了
小结V1.0
显然,不行。
状态说明V2.0
重新考虑上面的转移方程,我们可以发现对于每个子节点 v v v 我们仅使用了其两个状态,所以我们不妨尝试一下是否能够使用两个状态来进行表示。
设计 F [ i ] [ 0 ] F[i][0] F[i][0] 表示原先的 F [ i ] [ k ] F[i][k] F[i][k],设计 F [ i ] [ 1 ] F[i][1] F[i][1] 表示原先的 F [ i ] [ k − 1 ] F[i][k-1] F[i][k−1]
然后考虑新的转移方程。
状态分析V2.0
状态转移方程的确立
v
1
v1
v1 为选择
u
−
>
v
u -> v
u−>v 连边的节点,即选择
F
[
v
1
]
[
1
]
F[v1][1]
F[v1][1],其中
v
1
v1
v1 的多少设为
p
p
p
v
0
v0
v0 为所有儿子中除去
v
1
v1
v1 的节点。
F [ u ] [ 0 ] = m a x p < k + 1 { ∑ F [ v 1 ] [ 1 ] + ∑ F [ v 0 ] [ 0 ] } F[u][0] = max_{p < k+1 }\{\sum F[v1][1] + \sum F[v0][0]\} F[u][0]=maxp<k+1{∑F[v1][1]+∑F[v0][0]}
F [ u ] [ 1 ] = m a x p < k { ∑ F [ v 1 ] [ 1 ] + ∑ F [ v 0 ] [ 0 ] } F[u][1] = max_{p < k}\{\sum F[v1][1] + \sum F[v0][0]\} F[u][1]=maxp<k{∑F[v1][1]+∑F[v0][0]}
贪心优化转移
既然 选择 0 0 0 状态的节点可以任意多,那么不妨先假设全选 0 状态,然后选择 p p p 个节点,将其由状态 0 0 0 转换为状态 1 1 1 。
产生一个排序,排序规则为由大到小,排序关键字如下
F [ v ] [ 1 ] [ − F [ v ] [ 0 ] F[v][1][ - F[v][0] F[v][1][−F[v][0]
然后依次取出前面的至多 p p p 个正数(取负数会导致结果变小)
更新出相应的结果。
小优化
因为两个状态的转移区别仅限于第 k k k 个值,所以我们可以特判第 k k k 个值,找出 两个状态的区别。
时空复杂度分析
时间复杂度
由于使用了排序,因此为
O ( n l o g n ) O(n log n) O(nlogn)
空间复杂度
请看代码
完整代码
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<iostream>
#include<cstring>
using namespace std;
const int N = 5e5 + 17;
int n, k, tails, fr[N], to[N << 1], nxt[N << 1], ct[N << 1];
void add(int f, int t, int w) {
to[++tails] = t;
nxt[tails] = fr[f];
fr[f] = tails;
ct[tails] = w;
return;
}
void input() {
tails = 0;
scanf("%d %d", &n, &k);
for (int i = 1,p1,p2,p3; i < n; ++i) {
scanf("%d %d %d", &p1, &p2, &p3);
add(p1, p2, p3); add(p2, p1, p3);
}
}
long long F[N][2], tempSort[N];
bool comp(long long a, long long b) {
return a>b;
}
void DFS(int u, int fa) {
long long totalZero = 0;
int sons = 0;
F[u][0] = F[u][1] = 0;
for (int zj = fr[u],v; zj; zj = nxt[zj]) {
if ((v = to[zj]) == fa) continue;
DFS(v, u); ++sons;
F[v][1] += ct[zj];
totalZero += F[v][0];
}
if(sons == 0) return;
int cnt = 0;
for (int zj = fr[u], v; zj; zj = nxt[zj]) {
if ((v = to[zj]) == fa) continue;
tempSort[++cnt] = F[v][1] - F[v][0];
}
sort(tempSort + 1, tempSort + cnt + 1, comp);
int minfors = min(sons,k);
for(int i = 1;i <= minfors;++i)
if(tempSort[i] > 0)
totalZero += tempSort[i];
else
break;
F[u][1] = F[u][0] = totalZero;
if(sons >= k && tempSort[k] > 0)
F[u][1] -= tempSort[k];
return;
}
void Clear(){
memset(fr,0,(n+1)<<2);
}
void Work() {
input();
DFS(1, 0);
printf("%I64d\n", F[1][0]);
Clear();
}
int main() {
int t; scanf("%d", &t);
while (t--) Work();
return 0;
}