题目链接:点我啊╭(╯^╰)╮
题目大意:
给出一颗根为
1
1
1 的树,自由选择
K
K
K 个关键点
点
x
x
x 的最远距离为
x
x
x 到根的路径上遇到的第一个关键点的距离
整颗树的权值为所有最远距离的最小值
求
K
=
1
,
2...
n
K = 1,2...n
K=1,2...n 的所有权值和
解题思路:
考虑单个
K
K
K 咋求
可以二分答案
m
i
d
mid
mid ,每次找到深度最大的那个节点
x
x
x
然后要求最远距离不超过
m
i
d
mid
mid,因此
x
x
x 向上跳
m
i
d
mid
mid 步到达
y
y
y
则
y
y
y 为一个关键点,删除
y
y
y 的子树,继续找最大节点
直到全部删完,判断关键点的个数
那么就可以按照
d
f
s
dfs
dfs序建线段树
深度最大就直接找,删除子树就可以区间删
时间复杂度:
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n)
可以发现若答案为
x
x
x,则关键点的个数为
n
x
+
1
\frac{n}{x} + 1
xn+1
假设
f
(
x
)
=
k
f(x) = k
f(x)=k 为答案为
x
x
x 时关键点个数为
k
k
k
然后就不二分答案了,直接枚举答案为
0
~
n
0~n
0~n
求得
f
(
x
)
f(x)
f(x) 的所有值,那么
a
n
s
[
f
(
x
)
]
=
x
ans[f(x)] = x
ans[f(x)]=x
然后对
a
n
s
ans
ans 求个前缀
m
i
n
min
min 即可
考虑时间复杂度,因为是枚举答案
0
~
n
0~n
0~n
因此关键点的总数为
n
1
+
n
2
+
n
3
+
.
.
.
+
n
n
=
n
l
o
g
n
\frac{n}{1} + \frac{n}{2} + \frac{n}{3} + ...+ \frac{n}{n} = nlogn
1n+2n+3n+...+nn=nlogn
线段树上的操作都是对关键点进行操作
因此总的时间复杂度为:
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n)
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
using namespace std;
typedef long long ll;
typedef pair <int,int> pii;
const int maxn = 2e5 + 5;
int n, ans[maxn], tot;
int in[maxn], out[maxn], dfn[maxn];
int fa[maxn][21], dep[maxn];
int t[maxn<<2], lz[maxn<<2], tt[maxn<<2];
vector <int> g[maxn], vt;
void dfs(int u, int f) {
dep[u] = dep[f] + 1;
in[u] = ++tot, dfn[tot] = u;
fa[u][0] = f;
for(int i=1; i<=20; i++)
fa[u][i] = fa[fa[u][i-1]][i-1];
for(auto v : g[u]) {
if(v == f) continue;
dfs(v, u);
}
out[u] = tot;
}
inline void pushup(int rt) {
if(dep[t[rt<<1]] > dep[t[rt<<1|1]]) t[rt] = t[rt<<1];
else t[rt] = t[rt<<1|1];
}
void build(int l, int r, int rt) {
lz[rt] = 0;
if(l == r) {
t[rt] = tt[rt] = dfn[l];
return;
}
int mid = l + r >> 1;
build(l, mid, rt<<1);
build(mid+1, r, rt<<1|1);
pushup(rt); tt[rt] = t[rt];
}
inline void pushdown(int rt) {
if(!lz[rt]) return;
vt.push_back(rt<<1), vt.push_back(rt<<1|1);
t[rt<<1] = t[rt<<1|1] = 0;
lz[rt<<1] = lz[rt<<1|1] = 1;
lz[rt] = 0;
}
void update(int L, int R, int l, int r, int rt) {
vt.push_back(rt);
if(l>R || r<L) return;
if(l>=L && r<=R) {
t[rt] = 0, lz[rt] = 1;
return;
}
pushdown(rt);
int mid = l + r >> 1;
update(L, R, l, mid, rt<<1);
update(L, R, mid+1, r, rt<<1|1);
pushup(rt);
}
inline int get(int u, int num) {
for(int i=20; ~i; --i)
if(num >= (1 << i))
num -= (1 << i), u = fa[u][i];
return u ? u : 1;
}
inline int gao(int x) {
vt.clear();
int ret = 0, now, fnow;
while(true) {
now = t[1];
if(now == 0) break;
fnow = get(now, x);
update(in[fnow], out[fnow], 1, n, 1);
++ret;
}
for(auto i : vt) t[i] = tt[i], lz[i] = 0;
return ret;
}
signed main() {
while(~scanf("%d", &n)) {
tot = 0;
for(int i=1; i<=n; ++i) g[i].clear(), ans[i] = n + 1;
for(int i=2, x; i<=n; ++i) {
scanf("%d", &x);
g[x].push_back(i);
}
dfs(1, 0);
build(1, n, 1);
for(int i=n; ~i; --i) ans[gao(i)] = i;
for(int i=2; i<=n; ++i) ans[i] = min(ans[i-1], ans[i]);
ll ret = 0;
for(int i=1; i<=n; ++i) ret += ans[i];
printf("%lld\n", ret);
}
}