题目大意
有一个有 n n n个节点的树,每条边都有边权。给你一个整数 k k k,你要在书中选择 k k k个点,将这些点染成黑色,并将其他 n − k n-k n−k点染成白色。当所有点都染色后,你的收益为黑点两两之间的距离和加上白点两两之间的距离和,求收益的最大值是多少。
0 ≤ k ≤ n ≤ 2000 0\leq k\leq n\leq 2000 0≤k≤n≤2000
题解
我们考虑每条边的贡献,一条边的贡献为所有经过这条边且两段颜色相同的路径条数乘上边权,也就是边两侧黑色点数的乘积加上两侧白色点数的乘积之和再乘上边权。设一条边连接两个点 u , v u,v u,v( u u u是 v v v的父亲),边权为 w w w, v v v的子树大小为 s i z v siz_v sizv,在 v v v中选了 p p p个点,则这条边被经过的次数为
t m p = p × ( k − p ) + ( s i z v − p ) ( n − k − s i z v + p ) tmp=p\times (k-p)+(siz_v-p)(n-k-siz_v+p) tmp=p×(k−p)+(sizv−p)(n−k−sizv+p)
那么,DP式为
f u , j = max ( f u , j , f i , j − k + f v , k + t m p × w ) f_{u,j}=\max(f_{u,j},f_{i,j-k}+f_{v,k}+tmp\times w) fu,j=max(fu,j,fi,j−k+fv,k+tmp×w)
如果直接用这个式子的话,时间复杂度是 O ( n 3 ) O(n^3) O(n3)的(当然,这是跑不满的),这在洛谷上可以过前 10 10 10个点,拿 100 100 100分,但第 11 11 11个点会 T L E TLE TLE。为了方便理解,我先把这个代码放在下面,然后再说如何优化。
code
#include<bits/stdc++.h>
using namespace std;
int n,k,x,y,tot=0,d[5005],l[25005],r[5005],siz[2005];
long long z,w[5005],f[2005][2005];
void add(int xx,int yy,long long zz){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;w[tot]=zz;
}
void dfs1(int u,int fa){
siz[u]=1;
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
dfs1(d[i],u);
siz[u]+=siz[d[i]];
}
}
void dfs2(int u,int fa){
for(int i=0;i<=k;i++) f[u][i]=-1e15;
f[u][0]=f[u][1]=0;
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
dfs2(d[i],u);
for(int j=min(k,siz[u]);j>=0;j--){
for(int p=0;p<=min(j,siz[d[i]]);p++){
int tmp=p*(k-p)+(siz[d[i]]-p)*(n-k-siz[d[i]]+p);
f[u][j]=max(f[u][j],f[u][j-p]+f[d[i]][p]+tmp*w[i]);
}
}
}
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1;i<n;i++){
scanf("%d%d%lld",&x,&y,&z);
add(x,y,z);add(y,x,z);
}
dfs1(1,0);
dfs2(1,0);
printf("%lld",f[1][k]);
return 0;
}
下面考虑优化。
我们将每个点的各个儿子中子树的节点数量最多的儿子称为重儿子,其余为轻儿子。重儿子与父亲的连边称为重边,其余边称为轻边,重边连成的链称为重链。那么每次先遍历重儿子,此时因为只有 f u , 0 f_{u,0} fu,0和 f u , 1 f_{u,1} fu,1有值,所以我们可以 O ( k ) O(k) O(k)将 u u u的重儿子的贡献算到 u u u中。对于轻儿子,我们还是按原来那样 O ( k × s i z v ) O(k\times siz_v) O(k×sizv)转移。
轻儿子的 O ( k × s i z v ) O(k\times siz_v) O(k×sizv)转移,可以看作 v v v的子树中每个点都贡献了 O ( k ) O(k) O(k)的时间复杂度。也就是说,每个点对所有轻边转移的时间复杂度的贡献为这个点到根节点上轻边的数量乘上 k k k。每个点到根节点的路径上最多只会有 log n \log n logn条轻边(每从轻儿子沿轻边向上,子树大小至少为原来的两倍),所以一个点的贡献为 O ( k log n ) O(k\log n) O(klogn),所有点的总贡献为 O ( n k log n ) O(nk\log n) O(nklogn)。
再考虑重链的贡献。因为两条轻边之间至多只有一条重链,所以最多只有 log n \log n logn条重链。每条重链的转移是 O ( k ) O(k) O(k)的,那么所有重链对时间复杂度的贡献为 O ( n k ) O(nk) O(nk)。
那么,总时间复杂度为 O ( n k log n ) O(nk\log n) O(nklogn)。
code
#include<bits/stdc++.h>
using namespace std;
int n,k,x,y,tot=0,d[5005],l[25005],r[5005],siz[2005],son[2005];
long long z,w[5005],tf[2005],f[2005][2005];
void add(int xx,int yy,long long zz){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;w[tot]=zz;
}
void dfs1(int u,int fa){
siz[u]=1;
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
dfs1(d[i],u);
siz[u]+=siz[d[i]];
if(siz[d[i]]>siz[son[u]]){
son[u]=d[i];tf[u]=w[i];
}
}
}
void dfs2(int u,int fa){
for(int i=0;i<=k;i++) f[u][i]=-1e15;
f[u][0]=f[u][1]=0;
if(son[u]){
dfs2(son[u],u);
for(int j=0;j<=min(k,siz[son[u]]);j++){
int tmp=j*(k-j)+(siz[son[u]]-j)*(n-k-siz[son[u]]+j);
f[u][j]=max(f[u][j],f[son[u]][j]+tmp*tf[u]);
f[u][j+1]=max(f[u][j+1],f[son[u]][j]+tmp*tf[u]);
}
}
for(int i=r[u];i;i=l[i]){
if(d[i]==fa||d[i]==son[u]) continue;
dfs2(d[i],u);
for(int j=min(k,siz[u]);j>=0;j--){
for(int p=0;p<=min(j,siz[d[i]]);p++){
int tmp=p*(k-p)+(siz[d[i]]-p)*(n-k-siz[d[i]]+p);
f[u][j]=max(f[u][j],f[u][j-p]+f[d[i]][p]+tmp*w[i]);
}
}
}
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1;i<n;i++){
scanf("%d%d%lld",&x,&y,&z);
add(x,y,z);add(y,x,z);
}
dfs1(1,0);
dfs2(1,0);
printf("%lld",f[1][k]);
return 0;
}