题目
题目描述
对于一颗有根树,记
f
(
x
,
i
)
f(x,i)
f(x,i) 为
x
x
x 的子树中到
x
x
x 距离为
i
i
i 的点的数量。
对于所有 f ( x , i ) f(x,i) f(x,i),求前 k k k 大的值。
数据范围与提示
n
≤
3
×
1
0
6
n\le 3\times 10^6
n≤3×106 且
k
≤
1
0
18
k\le 10^{18}
k≤1018 。
思路
本来以为是二分,其实根本不需要二分,因为所有的结果是真的可以直接存下来的——桶计数。
显然是长链剖分。特性是,比最长的短链更长、比长链更短,这些值是没有任何改变的。那么我们可以用一个懒标记,表示这个桶内的每个值都应该考虑 a l l all all 次。不过比最长的短链更短时,它们不应该受到这个全局懒标记的影响,可以用 t a g i tag_i tagi 来修正,比如实际上每个值只计算 a l l − t a g i all-tag_i all−tagi 次。
然后就没了。时间复杂度 O ( n ) \mathcal O(n) O(n) 。
代码
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long int_;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0'||c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c&&c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
const int MaxN = 3000005;
struct Edge{
int to, nxt; Edge(){}
Edge(int T,int N){
to = T, nxt = N;
}
};
Edge e[MaxN];
int head[MaxN], cntEdge;
void addEdge(int a,int b){
e[cntEdge] = Edge(b,head[a]);
head[a] = cntEdge ++;
}
int heit[MaxN], son[MaxN];
void getInfo(int x){
son[x] = heit[0] = 0;
for(int i=head[x]; ~i; i=e[i].nxt){
getInfo(e[i].to);
if(heit[e[i].to] > heit[son[x]])
son[x] = e[i].to;
}
heit[x] = heit[son[x]]+1;
}
int dp[MaxN], tag[MaxN], all[MaxN];
long long buc[MaxN]; // bucket sort
int zz[MaxN]; // pointer
void relieve(int x,int len){
for(int i=0; i<len; ++i){
buc[dp[zz[x]+i]] += all[x]-tag[zz[x]+i];
tag[zz[x]+i] = all[x]; // erased
}
}
void dfs(int x){
if(heit[x] == 1){
dp[zz[x]] = all[x] = 1;
tag[zz[x]] = 0; return ;
}
zz[son[x]] = zz[x]+1;
dfs(son[x]); dp[zz[x]] = 1;
tag[zz[x]] = all[x] = all[son[x]];
int len = 0; // how long is relieved
for(int i=head[x],y; ~i; i=e[i].nxt)
if(e[i].to != son[x]){
zz[e[i].to] = zz[x]+heit[x];
dfs(y = e[i].to);
len = max(len,heit[y]);
relieve(x,heit[y]+1);
relieve(y,heit[y]);
rep(j,0,heit[y]-1) // update
dp[zz[x]+j+1] += dp[zz[y]+j];
}
++ all[x]; // long-son and itself
rep(j,0,len) tag[zz[x]+j] = all[x]-1;
}
int main(){
int n = readint();
long long k; scanf("%lld",&k);
rep(i,1,n) head[i] = -1;
rep(i,2,n)
addEdge(readint(),i);
getInfo(1), dfs(1);
relieve(1,heit[1]);
long long ans = 0;
drep(i,n,1) // every kind of value
if(k >= buc[i])
ans += i*buc[i], k -= buc[i];
else{ ans += 1ll*i*k; break; }
printf("%lld\n",ans);
return 0;
}