题目链接
先考虑n ^ 2如何做
不是直接暴力dfs
设 f [ u ] [ i ] f[u][i] f[u][i]为 u u u的子树中,到 u u u距离为i的点有多少个, g [ u ] [ i ] g[u][i] g[u][i]表示在 u u u的子树中,有两个点到他们的 l c a lca lca的距离为 d d d,并且 u u u到他们的 l c a lca lca的距离为 d − i d - i d−i的方案数有多少个
考虑动态地去计算这两个东西 像树形背包那样
显然有
g
[
u
]
[
i
]
+
=
f
[
v
]
[
i
−
1
]
∗
f
[
u
]
[
i
]
g[u][i] += f[v][i - 1] * f[u][i]
g[u][i]+=f[v][i−1]∗f[u][i]
g
[
u
]
[
i
]
+
=
g
[
v
]
[
i
+
1
]
g[u][i] += g[v][i + 1]
g[u][i]+=g[v][i+1]
f
[
u
]
[
i
]
+
=
f
[
v
]
[
i
−
1
]
f[u][i] += f[v][i - 1]
f[u][i]+=f[v][i−1]
第一个式子表示以 u u u为 l c a lca lca的方案数有多少个,剩下两个式子直接继承儿子的答案
计算答案也是动态更新的
当每次扫到一个儿子时
a
n
s
+
=
f
[
u
]
[
i
]
∗
g
[
v
]
[
i
+
1
]
+
g
[
u
]
[
i
]
∗
f
[
v
]
[
i
−
1
]
ans += f[u][i] *g[v][i + 1] + g[u][i] * f[v][i - 1]
ans+=f[u][i]∗g[v][i+1]+g[u][i]∗f[v][i−1]
表示三个点里有两个点在
v
v
v这个子树中的方案数和一个点在
v
v
v这个子树中的方案数
然后n ^ 2的就搞定了
考虑怎么优化
上个神奇的东西:长链剖分
考虑每次选到叶节点距离最大的儿子当重儿子,然后把重儿子的答案直接继承过来,轻儿子的暴力统计
这样做是O(n)的
因为轻儿子的暴力统计部分的复杂度是O(轻儿子的链长)
而长链剖分保证了O(
∑
轻
儿
子
链
长
\sum 轻儿子链长
∑轻儿子链长) = O(n)
接下来讲讲具体怎么从重儿子继承
考虑类似指针搞搞,就是注意到每次继承的时候,都是
f
[
u
]
[
i
]
=
f
[
s
o
n
[
u
]
]
[
i
−
1
]
,
g
[
u
]
[
i
]
=
g
[
s
o
n
[
u
]
]
[
i
+
1
]
f[u][i] = f[son[u]][i - 1] , g[u][i] = g[son[u]][i + 1]
f[u][i]=f[son[u]][i−1],g[u][i]=g[son[u]][i+1]
那其实只用指针移位一下就可以了
不过这样貌似有点麻烦
更简单的实现是用指针分配连续空间,也就是开一个 t m p tmp tmp数组用来存 f f f和 g g g的值(把一条长链上的所有点都分配在一起
遇到一个长链的顶端就分配给他他所在的长链长度*2的空间(*2是因为 g g g数组在继承时会向后移位),然后类似指针移位再分配空间就好了
这样写起来比较舒服
PS.貌似这题其实可以重儿子随便选?(不过比较不好写就是了
代码:
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<queue>
#include<ctime>
#include<stack>
#include<cmath>
using namespace std;
typedef long long LL;
const LL INF = 100000000000000LL;
const int maxn = 100010;
vector<int> e[maxn];
int n,len[maxn],son[maxn];
LL *f[maxn],*g[maxn],tmp[maxn * 6],*it = tmp,ans;
inline LL getint()
{
LL ret = 0,f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
ret = ret * 10 + c - '0',c = getchar();
return ret * f;
}
inline void dfs(int u,int fa)
{
for (int i = 0; i < e[u].size(); i++)
{
int v = e[u][i];
if (v == fa) continue;
dfs(v,u);
if (len[v] > len[son[u]]) son[u] = v;
}
len[u] = len[son[u]] + 1;
}
inline void dp(int u,int fa)
{
if (son[u]) f[son[u]] = f[u] + 1 , g[son[u]] = g[u] - 1 , dp(son[u],u);
f[u][0] = 1; ans += g[u][0];
for (int i = 0; i < e[u].size(); i++)
{
int v = e[u][i];
if (v == fa || v == son[u]) continue;
f[v] = it; it += (len[v] << 1) + 1; g[v] = it; it += (len[v] << 1) + 1;
dp(v,u);
for (int j = 0; j <= len[v]; j++)
{
if (j) ans += g[u][j] * f[v][j - 1];
if (j < len[v]) ans += f[u][j] * g[v][j + 1];
}
for (int j = 0; j <= len[v]; j++)
{
if (j) g[u][j] += f[v][j - 1] * f[u][j];
if (j < len[v]) g[u][j] += g[v][j + 1];
if (j) f[u][j] += f[v][j - 1];
}
}
}
int main()
{
#ifdef AMC
freopen("AMC1.txt","r",stdin);
freopen("AMC2.txt","w",stdout);
#endif
n = getint();
for (int i = 1; i <= n - 1; i++)
{
int u = getint(),v = getint();
e[u].push_back(v); e[v].push_back(u);
}
dfs(1,0);
f[1] = it; it += (len[1] << 1) + 1;
g[1] = it; it += (len[1] << 1) + 1;
dp(1,0);
printf("%lld\n",ans);
return 0;
}