分析
树的直径+二分+multiset有序多重集
(其实可以把multiset换成vector 不过要慢一些)
先用树形DP(或者两次DFS)求树的直径,作为二分的右边界r;树上最短的边作为二分的左边界l
二分最短赛道的长度mid=(l+r)/2
check(mid)函数中判断是否有至少m条赛道满足长度>=mid
如何判断?
把1号点作为根节点(选取任意一条点作为根节点都可以),向子树dfs,从叶节点向上递归。
举个例子:
假设当前节点为x,它的一个子节点为y。x,y构成的边长为w,dfs(y,x,k)表示以y为根的子树中与y相连的长度小于mid的最长链的长度
若dfs(y,x,k)+w>=mid 则ans++
否则将这一条连接x链的长度加到multiset集合s中
(因为这一条链与其他连接x的链相连后长度才有可能>=mid,这也解释了为什么dfs()返回的值总是小于mid)
接下来处理s集合中的链:
找出最短的一条。因为这个集合是从小到大排序,找出s.begin()指向的链即可。
再用lower_bound查找第一个长度>=(mid-*s.begin())的链。
若这样的两条链存在,则这两条链可以配对(长度和>=mid),ans++,并删除这两个数。
否则不断更新长度小于的mid的最长链,最后返回这个值。
感觉思路还挺简单的是不是? 考场上只能打暴力的蒟蒻两行泪
代码如下
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=50050,inf=1e9;
int n,m,head[N],tot=0,ans=0;
int d[N],v[N];
multiset<int>s[N];
multiset<int>::iterator it;
struct edge{
int ver,to,w;
}e[N*2];
ll read(){
ll sum=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-')f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
sum=(sum<<3)+(sum<<1)+ch-'0';
ch=getchar();
}
return sum*f;
}
void add(int x,int y,int z){
e[++tot].ver=y;
e[tot].w=z;
e[tot].to=head[x];
head[x]=tot;
}
int dfs(int x,int pre,int k){
s[x].clear();
int now;
for(int i=head[x];i;i=e[i].to)
{
int y=e[i].ver;
if(y==pre)continue;
now=e[i].w+dfs(y,x,k);
if(now>=k)ans++;
else{
s[x].insert(now);
}
}
int maxi=0;
while(!s[x].empty()){
if(s[x].size()==1){
return max(maxi,*s[x].begin());
}
it=s[x].lower_bound(k-*s[x].begin());
if(it==s[x].begin()&&s[x].count(*it)==1){ it++;}
if(it==s[x].end()){
maxi=max(maxi,*s[x].begin());
s[x].erase(s[x].begin());
}
else{
ans++;
s[x].erase(it);
s[x].erase(s[x].begin());
}
}
return maxi;
}
int check(int k){
ans=0;
dfs(1,0,k);
if(ans>=m)return 1;
return 0;
}
int up=0;
void dp(int x){
v[x]=1;
for(int i=head[x];i;i=e[i].to){
int y=e[i].ver;
if(v[y])continue;
dp(y);
up=max(up,d[x]+d[y]+e[i].w);
d[x]=max(d[x],d[y]+e[i].w);
}
}
int main(){
// freopen("track.in","r",stdin);
// freopen("track.out","w",stdout);
n=read();
m=read();
int x,y,z,l=inf,r=0,mid,res;
for(int i=1;i<n;i++)
{
x=read();
y=read();
z=read();
if(z<l)l=z;
add(x,y,z);
add(y,x,z);
}
dp(1);
r=up;
while(l<=r){
int mid=l+(r-l)/2;
if(check(mid)){
res=mid;
l=mid+1;
}
else{
r=mid-1;
}
}
cout<<res;
return 0;
}