题目大意
给定一个n个节点的有根树。设 r(a,b) ,其中b是a的祖先,这表示b的后代中除a以外深度不大于a的节点个数。设 z(a)=∑r(a,b) 。求每个 z(a)
n≤500000
分析
答案等于:
∑dep(b)≤dep(a)dep(lca(a,b))−dep(a)
这等价于
z(father(a))+∑dep(b)=dep(a)dep(lca(a,b))−dep(a)
接下来我们考虑把深度相同的节点放在一起算。
如果先把这些节点按dfs序排序,那么对于第i个节点,它前面的节点与它的lca是单调的。
那么一个思路出来了:正着扫一遍,反过来又扫一遍;用单调栈维护lca深度单调递增的节点。如果相邻的节点lca相同那么把它们合并。
正确性:首先两个节点被合并后不会再分开;其次每个节点只会与它前一个节点合并,之后的合并操作我们认为是它前面的节点与其它节点合并。所以合并次数是
O(n)
的;另外不存在栈顶不需合并,但栈中间需要合并的情况。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N=5e5+5,Log=19;
typedef long long LL;
int n,h[N],nxt[N],rt,fa[Log][N],Dep[N],Top,st[N],cnt[N],m,a[N],dep[N];
LL Ans[N],Now;
vector <int> H[N];
vector <int> ::iterator it;
char c;
int read()
{
int x=0,sig=1;
for (c=getchar();c<'0' || c>'9';c=getchar()) if (c=='-') sig=-1;
for (;c>='0' && c<='9';c=getchar()) x=x*10+c-48;
return x*sig;
}
void dfs(int x)
{
Dep[x]=Dep[fa[0][x]]+1; H[Dep[x]].push_back(x);
for (int i=h[x];i;i=nxt[i]) dfs(i);
}
int getlca(int x,int y)
{
for (int i=Log-1;i>=0;i--) if (fa[i][x]!=fa[i][y]) x=fa[i][x],y=fa[i][y];
return Dep[fa[0][x]];
}
int main()
{
n=read();
for (int i=1;i<=n;i++)
{
fa[0][i]=read();
if (!fa[0][i]) rt=i;else nxt[i]=h[fa[0][i]],h[fa[0][i]]=i;
}
dfs(rt);
for (int j=1;j<Log;j++)
for (int i=1;i<=n;i++) fa[j][i]=fa[j-1][fa[j-1][i]];
for (int i=1,j,k;!H[i].empty();i++)
{
Top=m=Now=0;
for (it=H[i].begin();it!=H[i].end();it++) a[++m]=*it;
for (j=1;j<=m;j++)
{
for (;Top>0;)
{
k=getlca(st[Top],a[j]);
if (k<dep[Top])
{
Now-=(LL)cnt[Top]*(dep[Top]-k); dep[Top]=k;
}
if (Top>1 && dep[Top-1]>=dep[Top])
{
cnt[Top-1]+=cnt[Top]; Now-=cnt[Top]*(dep[Top]-dep[Top-1]); Top--;
}else break;
}
Ans[a[j]]+=Now+Ans[fa[0][a[j]]]+Dep[a[j]]-1;
st[++Top]=a[j]; dep[Top]=Dep[a[j]]; Now+=dep[Top]; cnt[Top]=1;
}
Top=Now=0;
for (j=m;j;j--)
{
for (;Top>0;)
{
k=getlca(st[Top],a[j]);
if (k<dep[Top])
{
Now-=(LL)cnt[Top]*(dep[Top]-k); dep[Top]=k;
}
if (Top>1 && dep[Top-1]>=dep[Top])
{
cnt[Top-1]+=cnt[Top]; Now-=cnt[Top]*(dep[Top]-dep[Top-1]); Top--;
}else break;
}
Ans[a[j]]+=Now;
st[++Top]=a[j]; dep[Top]=Dep[a[j]]; Now+=dep[Top]; cnt[Top]=1;
}
}
for (int i=1;i<=n;i++) printf("%lld ",Ans[i]);
return 0;
}