原题链接:https://www.luogu.org/problemnew/show/P4149
虽然是道IOI的题,但是看起来好像并没有那么难。
对于这种树上的路径处理的问题,我们一般可以想到点分治的做法。
个人理解点分治的基本步骤:
- 先找到树的重心;
- 由重心统计所有从重心出发的路径,并进行相应的计算;
- 对重复的路径进行容斥;
- 对子树递归处理。
关键问题主要是解决第2步。
在这个题中,我们可以只对那些权值为k的路径记录他们的边的数量,和这个数量下的路径的数目(方便对于子树的容斥)。由于我们统计路径的时候是从重心出发的所有路径,因此我们需要两段路径拼接起来来形成一条完整的路径(点分治基本套路)。
在对权值为k的路径进行统计时,原来写的是:
//edge是所有从重心出发的边
for(int i=0,j=hd-1;i<j;i++){
while(edge[i].fi+edge[j].fi>k&&j>i){
j--;
}
if(edge[i].fi+edge[tj].fi==k){
mp[edge[i].se+edge[tj--].se]+=flag;
}
最后de了半天的bug(菜的真实),发现这样是不对的,会少计算权值一样的路径。
具体正解还是看下面的代码吧。
#include<bits/stdc++.h>
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
const int N=2e5+5;
const int inf=0x3f3f3f3f;
struct node{
int v,d,nxt;
}p[N<<1];
unordered_map<int,int>mp;
int head[N],tot,rt,sz,mn,tr,hd,k;
int mxson[N],sim[N],vis[N];
pii edge[N];
void add(int x,int y,int d){
p[tot].v=y,p[tot].d=d;
p[tot].nxt=head[x];
head[x]=tot++;
}
void init(){
tot=0;
memset(head,-1,sizeof(head));
}
void getroot(int x,int fa){//找重心
sim[x]=1,mxson[x]=0;
for(int i=head[x];~i;i=p[i].nxt){
int y=p[i].v;
if(y==fa||vis[y])
continue;
getroot(y,x);
sim[x]+=sim[y];
mxson[x]=max(sim[y],mxson[x]);
}
mxson[x]=max(sz-sim[x],mxson[x]);
if(mxson[x]<mn){
mn=mxson[x];
rt=x;
}
}
void dfs(int x,int fa,int d,int num){
edge[hd++]=pii(d,num);
for(int i=head[x];~i;i=p[i].nxt){
int y=p[i].v;
if(y==fa||vis[y])
continue;
dfs(y,x,d+p[i].d,num+1);
}
}
void solve(int x,int d,int flag){//找经过重心的路径
hd=0;
dfs(x,-1,d,flag==1?0:1);
sort(edge,edge+hd);
int cnt=0;
for(int i=0,j=hd-1;i<j;i++){
while(edge[i].fi+edge[j].fi>k&&j>i){
j--;
}
if(i!=0&&edge[i].fi==edge[i-1].fi){
if(edge[i].fi+edge[j].fi==k)
mp[edge[i].se+edge[j].se]+=flag*cnt;
}
else{
int tj=j;
while(edge[i].fi+edge[tj].fi==k){
mp[edge[i].se+edge[tj--].se]+=flag;
cnt++;
}
}
}
}
void dc(int x){
solve(x,0,1);
vis[x]=1;
for(int i=head[x];~i;i=p[i].nxt){
int y=p[i].v;
if(vis[y])
continue;
solve(y,p[i].d,-1);//容斥
mn=inf;
sz=sim[y];
getroot(y,x);
dc(rt);//对子树递归计算
}
}
int main(){
init();
int n,x,y,d;
scanf("%d%d",&n,&k);
for(int i=0;i<n-1;i++){
scanf("%d%d%d",&x,&y,&d);
add(x,y,d);
add(y,x,d);
}
mn=inf;
sz=n;
getroot(0,-1);
// cout<<rt<<endl;
dc(rt);
int ans=inf;
for(auto i:mp){
if(i.se>0){
ans=min(i.fi,ans);
}
}
printf("%d\n",ans==inf?-1:ans);
return 0;
}