题目描述:
题目分析:
先考虑
k
=
n
k=n
k=n的情况,这是一个经典问题,结论为所有点到重心的距离和。
证明:
对任意一点
u
u
u,都有
d
i
s
(
p
i
,
p
i
+
1
)
≤
d
i
s
(
p
i
,
u
)
+
d
i
s
(
u
,
p
i
+
1
)
dis(p_i,p_{i+1})\le dis(p_i,u)+dis(u,p_{i+1})
dis(pi,pi+1)≤dis(pi,u)+dis(u,pi+1)
那么就有
∑
i
=
1
n
d
i
s
(
p
i
,
p
i
m
o
d
n
+
1
)
≤
2
∗
∑
i
=
1
n
d
i
s
(
p
i
,
u
)
\sum_{i=1}^ndis(p_i,p_{i\bmod n+1})\le2*\sum_{i=1}^ndis(p_i,u)
∑i=1ndis(pi,pimodn+1)≤2∗∑i=1ndis(pi,u)
当点
u
u
u为重心时,上式可以取到等号,即左边取到最大值。
因为重心的子树大小一定
≤
n
/
2
\le n/2
≤n/2,形象地理解就是我们可以在不同的子树间反复横跳,构造这样的序列也是很容易的,比如按照dfs序从1开始:1,1+n/2,2,2+n/2,3,3+n/2…可以发现,每次要么是跳进子树,要么是跳出子树。
另一种理解是根据Hall定理,一定可以找到反复横跳的完备匹配。
根据上面的分析可以看出,
k
=
n
k=n
k=n时的答案也可以表示为
2
∗
∑
u
,
v
有
边
相
连
,
边
权
为
w
w
∗
m
i
n
(
s
i
z
[
v
]
,
n
−
s
i
z
[
v
]
)
2*\sum_{u,v有边相连,边权为w}w*min(siz[v],n-siz[v])
2∗∑u,v有边相连,边权为ww∗min(siz[v],n−siz[v])
用这种方法解决原问题可以用
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示
i
i
i的子树中选了
j
j
j个点的最大距离和,转移时加上每条边的贡献,可以做到
O
(
n
k
)
O(nk)
O(nk)
看回原问题,我们相当于是要找到答案的k个点的重心。
考虑枚举重心,取离重心最远的k个点,且保证每个子树中的点不超过k/2个,直接这样做的复杂度是 O ( n 2 l o g n ) O(n^2logn) O(n2logn)
确定往哪边走时可以用nth_element确定前k大,所以问题的复杂度为
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
关于 k k k为奇数的问题暂时没有理解,存疑
Code:
#include<bits/stdc++.h>
#define maxn 200005
#define LL long long
using namespace std;
char cb[1<<18],*cs,*ct;
#define getc() (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<18,stdin),cs==ct)?0:*cs++)
inline void read(int &a){
char c;while(!isdigit(c=getc()));
for(a=c-'0';isdigit(c=getc());a=a*10+c-'0');
}
int n,m,siz[maxn];
bool vis[maxn];
int fir[maxn],nxt[maxn<<1],to[maxn<<1],w[maxn<<1],tot;
inline void line(int x,int y,int z){nxt[++tot]=fir[x],fir[x]=tot,to[tot]=y,w[tot]=z;}
void Getroot(int u,int ff,int tsz,int &g){
siz[u]=1; bool flg=1;
for(int i=fir[u],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=ff)
Getroot(v,u,tsz,g),siz[u]+=siz[v],flg&=siz[v]<<1<=tsz;
if(flg&&(tsz-siz[u])<<1<=tsz) g=u;
}
struct node{
LL d;int p;
bool operator < (const node &t)const{return d>t.d;}
}a[maxn];
int cnt[maxn],num;
void dfs(int u,int ff,LL dis,int id){
a[++num]=(node){dis,id};
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=ff) dfs(v,u,dis+w[i],id?id:v);
}
void TDC(int u,int tsz){
Getroot(u,0,tsz,u),vis[u]=1;
num=0,dfs(u,0,0,0),nth_element(a+1,a+m,a+1+num);
int v=0;
for(int i=1;i<=m;i++)
if(++cnt[a[i].p]>m>>1) {v=a[i].p;break;}
for(int i=fir[u];i;i=nxt[i]) cnt[to[i]]=0;
if(v&&!vis[v]) TDC(v,siz[v]<siz[u]?siz[v]:tsz-siz[u]);
else{
sort(a+1,a+1+num); LL ans=0;
for(int i=1,s=0;s<m;i++)
if(cnt[a[i].p]<m>>1) cnt[a[i].p]++,s++,ans+=a[i].d;
printf("%lld\n",ans<<1);
}
}
int main()
{
read(n),read(m);
for(int i=1,x,y,z;i<n;i++) read(x),read(y),read(z),line(x,y,z),line(y,x,z);
TDC(1,n);
}