问题重述
给你一个无向树,其中每条边有AB两种边权。给定一个k** k ≤ 20 k\leq 20 k≤20**,让你选k条边权值为A,剩下的边权值为B,最小化该条件下的树的直径。
问题分析
树的直径两种求法
- 两次BFS可以找到树的直径, 基于规则是图没有变化, 显然这个题目中不太适用.
- 一次DP可以找到树的直径, 并且DP方程是每次处理一个子树 ( 感觉可以为我所用? )
然后就选了一次DP来求解.
DP讲解
讲真当时做这个题没看到
k
≤
20
k\leq 20
k≤20的条件, 不然还有可能冲出来?
一看这个k这么小, 就想着能不能直接 DP[N][K], 表示N为根节点的子树里面, 选了K个A, 剩下的全选B的情况下, 最长的链长度.
然后想一下当时怎么求树的直径: 维护一个DP[N]表示以N为根的子树最长链的长度, 然后每次用DP[v]更新DP[u]的时候统计一次最大值.
那么我们要是想直接计算题目中的树的直径好像并不简单.
考虑把这个问题转化为二分答案, 即已知树的直径不能超过某个值mid, 求是否存在一种方案使得其可以满足.
另DP[u][k]表示以u为根节点的子树上, 选了k条A边的最长链长度, 初始条件: DP[叶子][0]=0, 考虑背包式的更新:
D P [ u ] [ k + 1 ] = max ( D P [ u ] [ p ] , D P [ v ] [ k − p ] + A ) , 如 果 D P [ u ] [ p ] + D P [ v ] [ l ] + A ≤ m i d D P [ u ] [ k ] = max ( D P [ u ] [ p ] , D P [ v ] [ k − p ] + B ) , 如 果 D P [ u ] [ p ] + D P [ v ] [ l ] + B ≤ m i d DP[u][k+1] = \max({DP[u][p], DP[v][k-p] + A}), 如果 DP[u][p] + DP[v][l] + A \leq mid\\ DP[u][k] = \max({DP[u][p], DP[v][k-p] + B}), 如果 DP[u][p] + DP[v][l] + B \leq mid DP[u][k+1]=max(DP[u][p],DP[v][k−p]+A),如果DP[u][p]+DP[v][l]+A≤midDP[u][k]=max(DP[u][p],DP[v][k−p]+B),如果DP[u][p]+DP[v][l]+B≤mid
默认初始值都搞成 + ∞ +\infty +∞(或者是mid+1), 能用就行.
最后的判断条件自然就是DP[root][k]是否小于等于mid了!
代码
#include<bits/stdc++.h>
using namespace std;
const int N = 2e4+17, M = 4e4+17, K=21;
int fr[N], to[M], nxt[M], len1[M], len2[M], tails, size[N];
void add(int f, int t, int l1, int l2){
to[++tails] = t;
nxt[tails] = fr[f];
fr[f] = tails;
len1[tails] = l1;
len2[tails] = l2;
}
int n, m;
long long l, r, ans, mid, dp[N][K], tp[K];
void Check(int u, int fat){
size[u] = dp[u][0] = 0;
for(int p=fr[u], v;p;p=nxt[p]){
if((v=to[p])==fat) continue;
Check(v, u);
int l1 = len1[p], l2 = len2[p];
int size1 = size[u], size2 = size[v];
int size3 = min(m, size1+size2+1);
for(int j=0;j<=size3;++j) tp[j] = mid+1;
for(int j=0;j<=size1; ++j)
for(int k=0;k<=size2 && j+k<=m;++k){
if(dp[u][j] + dp[v][k] + l1 <= mid)
tp[j+k+1] = min(tp[j+k+1], max(dp[u][j], dp[v][k]+l1));
if(dp[u][j] + dp[v][k] + l2 <= mid)
tp[j+k] = min(tp[j+k], max(dp[u][j], dp[v][k]+l2));
}
size[u] = size3;
for(int j=0;j<=size3;++j)
dp[u][j] = tp[j];
}
return;
}
void work(){
scanf("%d %d",&n,&m); tails = l = r = 0;
for(int i=1;i<=n;++i) fr[i] = 0;
for(int i=1,p1,p2,l1,l2;i<n;++i){
scanf("%d%d%d%d",&p1, &p2, &l1, &l2);
add(p1,p2,l1,l2); add(p2,p1,l1,l2);
r += max(l1, l2);
}
while(l <= r){
mid = (l+r)>>1;
Check(1, 0);
if(dp[1][m] <= mid){
ans = mid;
r = mid-1;
}else{
l = mid+1;
}
}
printf("%lld\n", ans);
}
int main(){
int T;scanf("%d",&T);
while(T--) work();
return 0;
}