题意
给出一棵n个节点的树,定义一个点u是深度为m的k-ary heap当且仅当:
当m=1时u是一个深度为1的k-ary heap
当m>1时则要满足有不小于k个u的儿子v满足v是深度为m-1的k-ary heap
设
dp(k,u)
d
p
(
k
,
u
)
表示u的子树中深度最大的k-ary heap的深度。
现在要求
∑nk=1∑nu=1dp(k,u)
∑
k
=
1
n
∑
u
=
1
n
d
p
(
k
,
u
)
n<=300000
n
<=
300000
分析
设
h(k,u)
h
(
k
,
u
)
表示最大的m满足u是深度为m的k-ary heap。
首先考虑暴力要怎么求,先枚举k,x的子树处理完后,x的深度就是儿子深度的第k大+1。
这样复杂度是
O(n2)
O
(
n
2
)
。
当
k=1
k
=
1
时我们可以这样
O(n)
O
(
n
)
做。当
k>1
k
>
1
时有如下性质:
h(k,u)≤lognk
h
(
k
,
u
)
≤
l
o
g
k
n
h(k,u)≥h(k+1,u)
h
(
k
,
u
)
≥
h
(
k
+
1
,
u
)
设
f(u,j)
f
(
u
,
j
)
表示最大的k,满足
h(k,u)=j
h
(
k
,
u
)
=
j
,显然状态数只有
O(nlogn)
O
(
n
l
o
g
n
)
个。
考虑如何求
f(u,j)
f
(
u
,
j
)
。
先把u的子树处理完,然后枚举j,把所有的
f(v,j−1)
f
(
v
,
j
−
1
)
降序排序,其中v是u的儿子。然后找到一个最大的k,满足
f(v,j−1)
f
(
v
,
j
−
1
)
中的第k大不小于k,那么这个k就是
f(u,j)
f
(
u
,
j
)
。
接下来可以从大到小枚举k,然后把所有
f(u,j)=k
f
(
u
,
j
)
=
k
的位置取出,更新dp数组。不难发现更新次数不超过
O(nlogn)
O
(
n
l
o
g
n
)
。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<vector>
#define mp(x,y) make_pair(x,y)
using namespace std;
typedef long long LL;
typedef pair<int,int> pi;
const int N=300005;
int n,cnt,last[N],dep[N],f[N][25],a[N],dp[N],fa[N];
struct edge{int to,next;}e[N*2];
vector<pi> vec[N];
LL ans;
void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}
bool cmp(int x,int y)
{
return x>y;
}
void dfs(int x)
{
dep[x]=1;f[x][1]=n;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa[x]) fa[e[i].to]=x,dfs(e[i].to),dep[x]=max(dep[x],dep[e[i].to]+1);
ans+=dep[x];
for (int i=2;i<20;i++)
{
int tot=0,k=0;
for (int j=last[x];j;j=e[j].next) if (e[j].to!=fa[x]) a[++tot]=f[e[j].to][i-1];
if (!tot) break;
sort(a+1,a+tot+1,cmp);
while (k<tot&&a[k+1]>k) k++;
f[x][i]=k;vec[k].push_back(mp(x,i));
}
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;i++)
{
int x,y;scanf("%d%d",&x,&y);
addedge(x,y);
}
dfs(1);
int sum=n;
for (int i=1;i<=n;i++) dp[i]=1;
for (int k=n;k>1;k--)
{
for (int j=0;j<vec[k].size();j++)
{
int x=vec[k][j].first,y=vec[k][j].second;
while (x&&dp[x]<y) sum+=y-dp[x],dp[x]=y,x=fa[x];
}
ans+=sum;
}
printf("%I64d",ans);
return 0;
}