题意
上小学的时候, A r e x t r e \sf Arextre Arextre 接触了 O I OI OI 。
植树节,小 A ( A r e x t r e \sf Arextre Arextre) 种了棵树。这棵树经 A r e x t r e \sf Arextre Arextre 之手种出,变得有了灵性,每当小 A 过掉一道黑题,它就会长高 1 nm 。小 A 上高中后,人们发现小 A 的树已经成了参天大树,人们在树底下生活、乘凉,受用了小 A 的树带来的无尽恩泽。由于小 A 还是个高中生,年纪尚轻,因此人们便叫他小EnZe。
小 A 的树由 n 个点 n-1 条长度不等的枝杈组成,树上两个点的距离定义为两个点通过树边的最短路长度。现在想知道在总共 n ( n − 1 ) 2 \frac{n(n-1)}{2} 2n(n−1) 条路径中,最长的前 K 条路径。你只需要输出这 K 条路径的长度。
n ≤ 2 e 5 , K ≤ min ( 2 e 5 , n ( n − 1 ) 2 ) , 1 ≤ w i ≤ 1 e 9 n\leq 2e5\;,\;K\leq \min(2e5\;,\;\frac{n(n-1)}{2})\;,\;1\leq w_i\leq 1e9 n≤2e5,K≤min(2e5,2n(n−1)),1≤wi≤1e9
2000 ms,512 mb
20%: n ≤ 2000 n\leq 2000 n≤2000
20%~40%: 保证树是一条链
题解
先讲讲官方正解吧。
首先对树进行点分治,对于每一个分治重心,bfs 求出分治区间
内所有点到它的距离,将这些距离以及对应的节点按照从小到大
存起来放到一个序列 q 中,对于每个点,记录下它是属于这个分
治重心的哪颗子树,对于 q 中每个位置,我们要存下这个位置之
前最近的一个跟当前点来自不同的子树的位置。
\;
接着,我们二分第 K 大的距离,设这个值为 Lim.
记录一个 cnt,初始为 0,枚举每个节点 x,在点分树上沿途查
找,对于一个分治重心 now,我们暴力在对应的 q 中用一个指针
扫一遍,找到与 x 距离大于等于 Lim 的点,并在途中给 cnt 累加,
如果 cnt 中途大于等于 K 了,那么就直接退出函数返回 K 。
最后返回 cnt。
\;
这样,做一次check的复杂度就是 O ( n log n + K ) O(n \log n + K) O(nlogn+K) 。
所以总的时间复杂度就是 O ( ( n log n + K ) log w m a x ) O((n \log n + K)\log w_{max}) O((nlogn+K)logwmax)。
是不是很详细 ?自我感觉二分时细节很多。
下面进入笔者的边分树做法。
如果我们能找到一种归类方法,把这 n ⋅ ( n − 1 ) / 2 n\cdotp(n-1)/2 n⋅(n−1)/2 条路经分到一些集合里,每个集合里的每条路径可以 O ( 1 ) O(1) O(1) 从该集合里某个比它大的路径得到,那么我们就可以用一个双端的有序数据结构维护,往里面插入每个集合的少量最大的路径,一旦该数据结构超过 K 个点,就把最小的删去。这应该是看到这道题最基本的思路之一:增量法。
那怎么归类路径呢?
这里有种好方法,边分树!我们建棵边分树,把每个点到它子树内的所有实点的距离处理出来,左边和右边分别存到两个数组 p l , p r pl,pr pl,pr 里面,然后按路径长度从大到小排序,然后用一个三元组 ( i , j , r o o t ) (i,j,root) (i,j,root) 表示一条路径, r o o t root root 表示该路径在边分树上的 l c a lca lca 编号, i i i 表示该路径的一个端点在 p l pl pl 中的编号(排序后), j j j 表示另一个端点在 p r pr pr 中的编号(排序后),这样一来, ( i , j , r o o t ) (i,j,root) (i,j,root) 可以从 ( i − 1 , j , r o o t ) (i-1,j,root) (i−1,j,root) 或 ( i , j − 1 , r o o t ) (i,j-1,root) (i,j−1,root) 枚过来,现在一个三元组和树上一条路径是一一对应的。可以注意到,如果是点分树就没有这样的性质了,因为不是二叉。
但是这里出现了一个问题,
(
i
,
j
,
r
o
o
t
)
(i,j,root)
(i,j,root) 的来源有两个,所以我们如果用不能去重的数据结构的话,就会
W
A
W\!A
WA ,所以我们可以用 set
或 map
来去重,在
(
i
,
j
,
r
o
o
t
)
(i,j,root)
(i,j,root) 转移到其它三元组之前把它去重,这样才能保证复杂度。由于 set
和 map
是即加即去重,所以长度并不会经历一个增加的过程,也就不会被误判超过 K 个点而误删。当然也可以强制让每个三元组只有一个来源,如 i,j 中较大的一个减一之类的,这样也可以不用 set
或 map
(但是你得能双端删除还是建议用这两个,除非你很生气,要手打非旋 Treap)
注:三元组的大于小于重载一定要写完,保证能够判重。
CODE
保留了部分分代码
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 400005
#define DB double
#define LL long long
#define ENDL putchar('\n')
#define SI set<cp>::iterator
#pragma GCC optimize(2)
LL read() {
LL f = 1,x = 0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
int n,m,i,j,s,o,k;
int Abs(int x) {return x<0? -x:x;}
struct cp{
int a,b,lc;LL ds;cp(){a=b=lc=ds=0;}
cp(int A,int B,int Lc,LL D){a=A;b=B;lc=Lc;ds=D;}
};
bool operator < (cp A,cp B) {
if(A.ds != B.ds) return A.ds < B.ds;
if(A.lc != B.lc) return A.lc < B.lc;
if(A.a != B.a) return A.a > B.a;
return A.b > B.b;
}
bool operator > (cp A,cp B) {return B < A;}
set<cp> st;
struct it{
int v;LL w; it(){v=w=0;}
it(int V,LL W){v=V;w=W;}
};
LL bu[1000005],cntb;
vector<it> g[MAXN];
int de[MAXN];
void dfst1(int x,int fa,LL ds,int rt) {
if(ds && x < rt) bu[++ cntb] = ds;
for(int i = 0;i < (int)g[x].size();i ++) {
if(g[x][i].v != fa) {
dfst1(g[x][i].v,x,ds + g[x][i].w,rt);
}
}return ;
}
int fa[MAXN],sn[MAXN],tl;
LL dis[MAXN];
void dfs0(int x,int ff) {
fa[x] = ff;sn[x] = 0;
for(int i = 0;i < (int)g[x].size();i ++) {
if(g[x][i].v != ff) {
sn[x] = g[x][i].v;
dis[g[x][i].v] = dis[x] + g[x][i].w;
dfs0(g[x][i].v,x);
}
}
if(!sn[x]) tl = x;
return ;
}
bool ok[MAXN];
void dfsbd0(int x,int fa) {
for(int i = 0;i < (int)g[x].size();i ++) {
if(g[x][i].v != fa) {
dfsbd0(g[x][i].v,x);
}else g[x].erase(g[x].begin()+i),i--;
}return ;
}
void dfsbd1(int x) {
if(g[x].size() > 3) {
int p1 = ++ n,p2 = ++ n,tm = 0;
for(int i = 0;i < (int)g[x].size();i ++) {
tm ^= 1;
if(tm) g[p1].push_back(g[x][i]);
else g[p2].push_back(g[x][i]);
}
g[x].clear();
g[x].push_back(it(p1,0));
g[x].push_back(it(p2,0));
}
for(int i = 0;i < (int)g[x].size();i ++) {
dfsbd1(g[x][i].v);
}return ;
}
void dfsbd2(int x,int fa,int fe) {
if(fa) g[x].push_back(it(fa,fe));
for(int i = 0;i < (int)g[x].size();i ++) {
if(g[x][i].v != fa) {
dfsbd2(g[x][i].v,x,g[x][i].w);
}
}return ;
}
bool vs[MAXN];
int siz[MAXN],ms[MAXN],SIZ,c1,c2,ce;
int lp[MAXN],rp[MAXN],le[MAXN],ls[MAXN],rs[MAXN];
int cnt;
vector<it> gl[MAXN],gr[MAXN];
bool cmp(it a,it b) {return a.w > b.w;}
void dfsi(int x,int fa,int fe) {
siz[x] = 1; ms[x] = 0;
for(int i = 0;i < (int)g[x].size();i ++) {
int y = g[x][i].v;
if(y != fa && !vs[y]) {
dfsi(y,x,g[x][i].w);
siz[x] += siz[y];
}
}
if(fa && Abs(SIZ-(siz[x]<<1)) <= Abs(SIZ-(siz[c1]<<1))) {
c1 = x;c2 = fa;ce = fe;
}
return ;
}
void adde(int x,int fa,LL ds,int rt) {
if(ok[x]) {
if(rt > 0) gl[rt].push_back(it(x,ds));
else gr[-rt].push_back(it(x,ds));
}
for(int i = 0;i < (int)g[x].size();i ++) {
int y = g[x][i].v;
if(y != fa && !vs[y]) {
adde(y,x,ds+g[x][i].w,rt);
}
}return ;
}
int solve(int a,int b,int wi) {
int x = ++ cnt;
lp[x] = a; rp[x] = b; le[x] = wi;
adde(a,b,0ll,x); adde(b,a,0ll,-x);
if(!gl[x].empty() && !gr[x].empty()) {
sort(gl[x].begin(),gl[x].end(),cmp);
sort(gr[x].begin(),gr[x].end(),cmp);
st.insert(cp(0,0,x,gl[x][0].w+gr[x][0].w+wi));
}
dfsi(a,b,wi); dfsi(b,a,wi);
SIZ = siz[a]; c1 = c2 = ce = 0; vs[b] = 1; vs[a] = 0;
dfsi(a,0,0); if(c1) ls[x] = solve(c1,c2,ce);
SIZ = siz[b]; c1 = c2 = ce = 0; vs[a] = 1; vs[b] = 0;
dfsi(b,0,0); if(c1) rs[x] = solve(c1,c2,ce);
return x;
}
int main() {
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
n = read();m = read();
int ct1 = 0;
for(int i = 1;i < n;i ++) {
s = read(); o = read(); k = read();
g[s].push_back(it(o,k));
g[o].push_back(it(s,k));
de[s] ++; de[o] ++;
if(de[s] == 1) ct1 ++; if(de[o] == 1) ct1 ++;
if(de[s] == 2) ct1 --; if(de[o] == 2) ct1 ++;
}
for(int i = 1;i <= n;i ++) random_shuffle(g[i].begin(),g[i].end()),ok[i] = 1;
if(n <= 2000) {
cntb = 0;
for(int i = 1;i <= n;i ++) {
dfst1(i,0,0ll,i);
}
sort(bu + 1,bu + 1 + cntb);
for(int i = cntb;i > 0 && i >= cntb-m+1;i --) {
printf("%lld\n",bu[i]);
}
return 0;
}
if(ct1 <= 2) {
int rt = 0;
for(int i = 1;i <= n;i ++) {
if(de[i] <= 1) rt = i;
}
dfs0(rt,0);
st.insert(cp(rt,tl,rt,dis[tl]));
for(int i = 1;i <= m;i ++) {
bu[i] = st.rbegin()->ds;
cp t = *st.rbegin(); st.erase(*st.rbegin());
cp t2 = t; t.a = t.lc = sn[t.a]; t.ds = dis[t.b]-dis[t.a];
t2.b = fa[t2.b]; t2.ds = dis[t2.b] - dis[t2.a];
st.insert(t); st.insert(t2);
while((int)st.size() > m) st.erase(st.begin());
}
for(int i = 1;i <= m;i ++) {
printf("%lld\n",bu[i]);
}
return 0;
}
dfsbd0(1,0);
dfsbd1(1);
dfsbd2(1,0,0);
SIZ = n; c1 = c2 = ce = 0;
dfsi(1,0,0);
solve(c1,c2,ce);
for(int i = 1;i <= m;i ++) {
bu[i] = st.rbegin()->ds;
cp t = *st.rbegin(); st.erase(t);
cp t2 = t;
if(t.a < (int)gl[t.lc].size()-1) {
t.a ++; t.ds = gl[t.lc][t.a].w+gr[t.lc][t.b].w+le[t.lc];
st.insert(t);
}
if(t2.b < (int)gr[t2.lc].size()-1) {
t2.b ++; t2.ds = gl[t2.lc][t2.a].w+gr[t2.lc][t2.b].w+le[t2.lc];
st.insert(t2);
}
while((int)st.size() > m) st.erase(st.begin());
}
for(int i = 1;i <= m;i ++) {
printf("%lld\n",bu[i]);
}
return 0;
}