DP 的状态设计得很巧妙。
设 $f_{i,j}$ 为在 $i$ 的子树中、与 $i$ 距离为 $j$ 的节点数目。
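设 $g_{i,j}$ 为 $i$ 的子树中满足如下条件的无序点对 $(u,v)$ 的数目：$u$、$v$ 到它们的 LCA 的距离都为 $d$，并且该 LCA 与 $i$ 的距离为 $d-j$。这样，若再有一个点与 $i$ 的距离为 $j$，且它到该 LCA 的路径经过 $i$，那么它到 LCA 的距离也恰为 $d$，三个点两两距离都是 $2d$，正好构成一组答案。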
那么转移就非常显然了。设 $y$ 是 $x$ 的一个儿子，先用尚未合并 $y$ 的 $f_x, g_x$ 与 $y$ 的信息统计答案，再把 $y$ 合并进 $x$（同一个 $j$ 上先更新 $g$ 再更新 $f$）：
$$ans \mathrel{+}= f_{x,j}\cdot g_{y,j+1} + g_{x,j}\cdot f_{y,j-1}$$
$$g_{x,j+1} \mathrel{+}= f_{x,j+1}\cdot f_{y,j}$$
$$g_{x,j-1} \mathrel{+}= g_{y,j}$$
$$f_{x,j+1} \mathrel{+}= f_{y,j}$$
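这个 DP 的第二维只与深度有关，显然可以用长链剖分优化：长儿子直接继承父节点的数组（$f_{son,j}$ 就存放在 $f_{x,j+1}$ 的位置，$g_{son,j}$ 就存放在 $g_{x,j-1}$ 的位置），其余儿子暴力合并，总复杂度 $O(n)$。然后就结束了。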
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cstdlib>
#include <vector>
using namespace std;
typedef long long ll;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it.
// Returns 0 if end of input is reached before any digit is seen, instead of
// looping forever.
inline long long read() {
    int c = getchar();  // kept as int so that EOF stays distinguishable from data
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    long long d = 0;
    while ('0' <= c && c <= '9') {
        d = d * 10 + (c - '0');
        c = getchar();
    }
    return d;
}
int const N = 100005;  // array capacity; nodes are numbered 1..n with n < N

// Adjacency lists stored as singly linked edge lists.
// last[x] is the index of the newest edge leaving x (0 means none), and
// e[k].next is the edge that was added before it. Index 0 is never used, so
// 0 doubles as the list terminator.
struct edge {
    int y, next;
} e[N << 1];
int last[N], ne;

// Prepends the directed edge x -> y to x's adjacency list.
void addedge(int x, int y) {
    e[++ne].y = y;
    e[ne].next = last[x];
    last[x] = ne;
}

// Pool backing the f/g arrays of the long-chain DP; cur is the next free cell.
// Every chain top t receives two slices of 2 * len[t] + 2 cells (see DP and
// main). Chains partition the nodes, so the tops' len values sum to n. There
// are at most n tops, so at most 4n + 4n = 8n cells are ever handed out.
// 7 * N is not enough: a star on 100000 nodes needs 12 + 8 * 99998 = 799996.
long long tmp[N * 8], *f[N], *g[N], *cur = tmp;
// len[x]: number of nodes on the longest downward path starting at x (a leaf has 1).
// son[x]: the child of x with the largest len (0 if x is a leaf).
int len[N], son[N];
int n;
long long ans;  // running count of triples
// First DFS pass: fills len[] and son[] for the subtree rooted at x.
// son[x] ends up as the first child (in adjacency order) with the greatest len.
// len[x] = len[son[x]] + 1, which is 1 for a leaf because len[0] is 0.
void dfs(int x, int fa) {
    for (int k = last[x]; k != 0; k = e[k].next) {
        const int child = e[k].y;
        if (child != fa) {
            dfs(child, x);
            if (len[son[x]] < len[child])
                son[x] = child;
        }
    }
    len[x] = 1 + len[son[x]];
}
// Second DFS pass: long-chain DP that counts node triples with equal pairwise
// distances and adds them to ans.
// After DP(x, fa) returns (indices are relative to the pointers f[x], g[x]):
//   f[x][j] = number of nodes in subtree(x) at distance j from x;
//   g[x][j] = number of unordered pairs {u, v} in subtree(x) with
//             dist(u, w) = dist(v, w) = d for their LCA w and
//             dist(w, x) = d - j. A third node at distance j from x, reached
//             through x, would complete a triple.
// Precondition: f[x] and g[x] already point into tmp (set by main for the
// root, or by x's parent). Those cells start at 0 because tmp is a global and
// cur only moves forward.
void DP(int x, int fa) {
if (son[x]) {
// The long child reuses x's storage: f[son][j] is f[x][j + 1] and
// g[son][j] is g[x][j - 1]. These are exactly the index shifts the
// transitions need, so the long child's result is inherited without copying.
f[son[x]] = f[x] + 1;
g[son[x]] = g[x] - 1;
DP(son[x], x);
}
f[x][0] = 1;  // x itself is at distance 0 from x
ans += g[x][0];  // pairs completed by x itself as the third node
for (int i = last[x]; i; i = e[i].next) {
if (e[i].y == son[x] || e[i].y == fa)
continue;
// Any other child y tops a new chain and gets fresh f and g slices of
// 2 * len[y] + 2 cells each. Deeper nodes on y's chain index g[y] at
// negative offsets (g[son] = g[x] - 1). Those cells fall in the unused
// upper part of the f slice allocated just before it.
f[e[i].y] = cur;
cur += (len[e[i].y] << 1) + 2;
g[e[i].y] = cur;
cur += (len[e[i].y] << 1) + 2;
DP(e[i].y, x);
// Count triples that combine y with the branches already merged into x
// (f[x] and g[x] do not include y yet):
//   node at distance j from x   with  pair in y waiting at distance j + 1 from y;
//   pair waiting at distance j from x  with  node at distance j - 1 from y.
for (int j = 0; j <= len[e[i].y]; ++j) {
ans += f[x][j] * g[e[i].y][j + 1];
if (j > 0)
ans += g[x][j] * f[e[i].y][j - 1];
}
// Merge y into x, shifting y's arrays by one level.
// New pairs whose LCA is x read f[x][j + 1] before it is incremented, so
// each such pair takes one node from y and one from an earlier branch.
// NOTE(review): in both loops the last iteration (j == len[y]) appears to
// touch only cells that are always 0 (subtree(y) has no node at distance
// len[y] from y). It looks harmless but could be dropped with j < len[y].
for (int j = 0; j <= len[e[i].y]; ++j) {
g[x][j + 1] += f[x][j + 1] * f[e[i].y][j];
if (j > 0)
g[x][j - 1] += g[e[i].y][j];
f[x][j + 1] += f[e[i].y][j];
}
}
}
// Reads the tree (n, then n - 1 undirected edges), runs both DFS passes from
// node 1, and prints the number of triples found.
int main() {
    n = read();
    for (int k = 1; k < n; ++k) {
        const int u = read();
        const int v = read();
        addedge(u, v);
        addedge(v, u);
    }
    dfs(1, 0);
    // Node 1 tops its own chain, so it gets fresh f and g slices of the pool.
    const int slice = (len[1] << 1) + 2;
    f[1] = cur;
    cur += slice;
    g[1] = cur;
    cur += slice;
    DP(1, 0);
    cout << ans << '\n';
}