题意
给一棵树,从中选三个点,使得三个点两两间距离相等,求方案数。
题解
对每一个结点,用
n
u
m
[
u
]
[
d
]
num[u][d]
num[u][d]表示子树中到当前结点u的距离为d的节点数,用
w
a
y
[
u
]
[
d
]
way[u][d]
way[u][d]表示已经有很多两个结点的配对,再添加一个到当前结点距离为d的结点即可构成一个方案的结点对数。
枚举子节点v,先计算答案
A
n
s
+
=
w
a
y
[
u
]
[
d
+
1
]
×
n
u
m
[
v
]
[
d
]
+
w
a
y
[
v
]
[
d
]
×
n
u
m
[
u
]
[
d
−
1
]
Ans+=way[u][d+1]\times num[v][d]+way[v][d]\times num[u][d-1]
Ans+=way[u][d+1]×num[v][d]+way[v][d]×num[u][d−1],然后再将v的状态加到u里面去。
w
a
y
[
u
]
[
d
+
1
]
+
=
n
u
m
[
v
]
[
d
]
∗
n
u
m
[
u
]
[
d
+
1
]
way[u][d+1]+=num[v][d]*num[u][d+1]
way[u][d+1]+=num[v][d]∗num[u][d+1],即将v的结点与u已经存在的配对。
w
a
y
[
u
]
[
d
−
1
]
+
=
w
a
y
[
v
]
[
d
]
way[u][d-1]+=way[v][d]
way[u][d−1]+=way[v][d],直接把v的方案加到u上
n
u
m
[
u
]
[
d
+
1
]
+
=
n
u
m
[
v
]
[
d
]
num[u][d+1]+=num[v][d]
num[u][d+1]+=num[v][d],直接把v的结点加到u上
显然可以用启发式合并,因为每次合并与最大深度有关,显然选择深度最大的子节点作为重儿子最优。
如何直接利用重儿子的way和num?
可以发现当前结点的way为重儿子左移一位的结果,num为重儿子右移一位的结果。
十分巧妙的方法:巧妙地开way和num的数组空间,使得当前结点可以重复利用重儿子的数组。把way[u]的起始位置设为way[v]+1,num[u]的起始位置设为num[u]-1。
关于时间复杂度。实际上是
O
(
n
)
O(n)
O(n)
每次合并时,为轻儿子的最大深度之和,即轻儿子的重链长度之和(重链总是通向最深的结点的),重链长度之和刚好等于节点数。所以合并总复杂度为
O
(
n
)
O(n)
O(n) ,重儿子无需合并。
代码
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int MAXN=100005;
struct Edge
{
int v;
Edge *nxt;
};
struct Graph
{
Edge edges[MAXN*2],*adj[MAXN],*edit;
void Init()
{
memset(adj,0,sizeof adj);
edit=edges;
}
void AddEdge(int u,int v)
{
edit->v=v;
edit->nxt=adj[u];
adj[u]=edit++;
}
};
int n;
long long ans;
Graph Tr;
int len[MAXN],dep[MAXN],mxlen[MAXN],son[MAXN],fa[MAXN];
long long pool[MAXN*3],*pl_it;
long long *num[MAXN],*way[MAXN];
void PreDFS(int u)
{
mxlen[u]=1;son[u]=0;
for(Edge *e=Tr.adj[u];e;e=e->nxt)
{
int v=e->v;
if(v==fa[u])
continue;
fa[v]=u;
dep[v]=dep[u]+1;
PreDFS(v);
mxlen[u]=max(mxlen[u],mxlen[v]+1);
if(mxlen[v]>mxlen[son[u]])
son[u]=v;
}
}
void dfs1(int u)
{
num[u]=pl_it++;
len[u]=1;
if(son[u])
dfs1(son[u]),len[u]+=len[son[u]];
for(Edge *e=Tr.adj[u];e;e=e->nxt)
if(e->v!=fa[u]&&e->v!=son[u])
dfs1(e->v);
}
void dfs2(int u)
{
for(Edge *e=Tr.adj[u];e;e=e->nxt)
if(e->v!=fa[u]&&e->v!=son[u])
dfs2(e->v),pl_it=way[e->v]+len[e->v];
if(son[u])
dfs2(son[u]);
way[u]=pl_it++;
}
void Solve(int u)
{
if(son[u])
Solve(son[u]);
num[u][0]=1;
ans+=way[u][0];
for(Edge *e=Tr.adj[u];e;e=e->nxt)
{
int v=e->v;
if(v==fa[u]||v==son[u])
continue;
Solve(v);
for(int d=0;d<len[v];d++)
{
ans+=way[u][d+1]*num[v][d];
if(d-1>=0)
ans+=way[v][d]*num[u][d-1];
}
for(int d=0;d<len[v];d++)
{
way[u][d+1]+=num[v][d]*num[u][d+1];
if(d-1>=0)
way[u][d-1]+=way[v][d];
num[u][d+1]+=num[v][d];
}
}
}
int main()
{
scanf("%d",&n);
Tr.Init();
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
Tr.AddEdge(u,v);
Tr.AddEdge(v,u);
}
PreDFS(1);
ans=0;
memset(pool,0,sizeof(long long)*n*3);
pl_it=pool;
dfs1(1);dfs2(1);
Solve(1);
printf("%lld\n",ans);
return 0;
}