给一个树,问有多少三元组满足两两距离相等。n<=100000
长链剖分应用之一:o(n)统计以深度为下标的可合并子树信息
在当前节点,令f(i)表示相对深度为i的节点个数,g(i)表示在子树外离当前点距离为i的点可以和子树内多少对点组成答案。
每次新来一个儿子,枚举长度,用当前g和儿子f以及当前f和儿子g更新一遍答案,然后用两边的f来更新g,再将儿子的f和g推入当前f和g。注意顺序问题。
初始时直接继承长儿子的信息,f数组相当于整体右移一位(指针左移一位),g数组相当于整体右移一位。一定要注意继承时也要考虑这个点自己和长儿子所在子树内组成的答案数(也就是继承后的g[0])
在dfs时给每条长链底端分配相当于长链长度的内存,注意g指针是每次右移的,右移了len次后还需要有len的空间,因此应分配2*len。
就可以o(n)解决这道题了。
#include<cstdio>
#define gm 100005
using namespace std;
typedef long long ll;
inline ll* __alloc(size_t size)
{
static ll pool[gm<<2];
static ll* ptr=pool;
ll* res=ptr;ptr+=size;
return res;
}
struct e
{
int t;
e *n;
e(int t,e *n):t(t),n(n){}
}*p[gm];
int dep[gm],son[gm],len[gm],fa[gm];
ll *F[gm],*G[gm];
void dfs(int x)
{
son[x]=x;
for(e *i=p[x];i;i=i->n)
{
if(i->t==fa[x]) continue;
fa[i->t]=x;
dep[i->t]=dep[x]+1;
dfs(i->t);
if(dep[son[i->t]]>dep[son[x]]) son[x]=son[i->t];
}
len[x]=dep[son[x]]-dep[x]+1;
for(e *i=p[x];i;i=i->n)
{
if(i->t==fa[x]) continue;
if(son[i->t]!=son[x])
{
int y=son[i->t];
F[y]=__alloc(len[i->t])+len[i->t]-1;
G[y]=__alloc(len[i->t]<<1);
}
}
if(x==1)
{
int y=son[x];
F[y]=__alloc(len[x])+len[x]-1;
G[y]=__alloc(len[x]<<1);
}
}
ll ans=0;
void DP(int x)
{
ll *&f=F[x],*&g=G[x];
for(e *i=p[x];i;i=i->n)
{
if(i->t==fa[x]) continue;
DP(i->t);
if(son[x]==son[i->t])
{
f=F[i->t]-1;
g=G[i->t]+1;
}
}
ans+=g[0];
++f[0];
for(e *i=p[x];i;i=i->n)
{
if(i->t==fa[x]||son[i->t]==son[x]) continue;
ll *fs=F[i->t],*gs=G[i->t];
ans+=fs[0]*g[1];
for(int w=1;w<len[i->t];++w)
ans+=fs[w]*g[w+1]+gs[w]*f[w-1];
g[1]+=fs[0]*f[1];
f[1]+=fs[0];
for(int w=1;w<len[i->t];++w)
{
g[w+1]+=fs[w]*f[w+1];
g[w-1]+=gs[w];
f[w+1]+=fs[w];
}
}
}
int n;
int main()
{
scanf("%d",&n);
for(int i=1;i<n;++i)
{
int u,v;
scanf("%d%d",&u,&v);
p[u]=new e(v,p[u]);
p[v]=new e(u,p[v]);
}
dfs(1);
DP(1);
printf("%lld\n",ans);
return 0;
}