题目大意:
求树上所有路径中的前k长路。
题解:
我们把这棵树的点分治序处理出来。假设我们确定了一个分治中心下的一条链,我们需要找到另一条链使得两条加起来最大。
那么另外一条可行链的端点在点分治序上一定形成一段区间。然后就变成了对于一个右端点都有一段可行的左端点,要求两点权值和最大。
之后就变成了BZOJ2006: [NOI2010]超级钢琴。点分治序+st表+堆。
代码:
#include<cstdio>
#include<algorithm>
#include<queue>
#define pr pair<int,int>
#define prr pair<pr,pr>
#define mp make_pair
#define fr first
#define sc second
using namespace std;
int n,k,cnt,num,root,vis[1000005],sz[1000005],f[1000005],last[1000005],l[1000005],r[1000005],st[1000005][21],a[1000005],lg[1000005];
priority_queue<prr> q;
struct node{
int to,next,val;
}e[100005];
void add(int a,int b,int c){
e[++cnt].to=b;
e[cnt].next=last[a];
e[cnt].val=c;
last[a]=cnt;
}
void getroot(int x,int fa){
sz[x]=1,f[x]=0;
for (int i=last[x]; i; i=e[i].next){
int V=e[i].to;
if (V==fa || vis[V]) continue;
getroot(V,x);
sz[x]+=sz[V];
f[x]=max(f[x],sz[V]);
}
f[x]=max(f[x],num-sz[x]);
if (f[root]>f[x]) root=x;
}
void getdis(int x,int fa,int dep){
a[++cnt]=dep,l[cnt]=l[cnt-1];
if (!r[cnt]) r[cnt]=r[cnt-1];
for (int i=last[x]; i; i=e[i].next){
int V=e[i].to;
if (vis[V] || V==fa) continue;
getdis(V,x,dep+e[i].val);
}
}
void dfs(int x){
vis[x]=1;
a[++cnt]=0,l[cnt]=cnt,r[cnt]=cnt-1;
for (int i=last[x]; i; i=e[i].next){
int V=e[i].to;
if (vis[V]) continue;
r[cnt+1]=cnt;
getdis(V,x,e[i].val);
}
for (int i=last[x]; i; i=e[i].next){
int V=e[i].to;
if (vis[V]) continue;
num=sz[V];
root=0;
getroot(V,x);
dfs(root);
}
}
int calc(int x,int y){
if (a[x]>a[y]) return x;
else return y;
}
int query(int a,int b){
if (a>b) return 0;
int len=lg[b-a+1];
return calc(st[a][len],st[b-(1<<len)+1][len]);
}
int main(){
scanf("%d%d",&n,&k);
for (int i=1; i<n; i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
num=n;
f[root]=1e9;
getroot(1,0);
dfs(root);
for (int i=1; i<=cnt; i++) st[i][0]=i;
for (int i=2; i<=cnt; i++) lg[i]=lg[i>>1]+1;
for (int j=1; (1<<j)<=cnt; j++)
for (int i=1; i+(1<<j)-1<=cnt; i++)
st[i][j]=calc(st[i][j-1],st[i+(1<<j-1)][j-1]);
for (int i=1; i<=cnt; i++){
if (l[i]>r[i]) continue;
q.push(mp(mp(a[i]+a[query(l[i],r[i])],i),mp(l[i],r[i])));
}
for (int i=1; i<=k; i++){
printf("%d\n",q.top().fr.fr);
int x=q.top().fr.sc,aa=q.top().sc.fr,bb=q.top().sc.sc,y=query(aa,bb);
q.pop();
int id1=query(aa,y-1);
int id2=query(y+1,bb);
if (id1) q.push(mp(mp(a[x]+a[id1],x),mp(aa,y-1)));
if (id2) q.push(mp(mp(a[x]+a[id2],x),mp(y+1,bb)));
}
return 0;
}