题目大意:在一棵无边权的树上选三个点,使得两两点对的距离相等,问有多少种选法。
题解:
答案的形式必然如图所示,当确定两个点,这两个点到其 lca
的距离都等于
p
p
p,那么第三个点一定来自不同子树,且到 lca 的距离也等于
p
p
p
考虑树形 dp:当前根结点为 u,合并第 i 棵子树时:在前 i - 1 棵子树里找一个点
z
z
z,在第 i 棵子树里找一个点对
(
x
,
y
)
(x,y)
(x,y),其 lca 为 f,使得
d
i
s
(
i
,
f
)
=
d
i
s
(
x
,
f
)
=
d
i
s
(
y
,
f
)
dis(i,f) = dis(x,f) = dis(y,f)
dis(i,f)=dis(x,f)=dis(y,f)
设 d i s ( x , f ) = d i s ( y , f ) = d dis(x,f) = dis(y,f) = d dis(x,f)=dis(y,f)=d,对于前 i - 1 棵子树中 深度为 i 的点,它只能与 d i s ( f , u ) = d − i dis(f,u) = d - i dis(f,u)=d−i 的点对构成一个答案。(反过来,在第 i 棵树中找一个点,前 i - 1棵树中找点对也是一样)
令状态
t
p
[
u
]
[
k
]
tp[u][k]
tp[u][k]表示 u 结点为根节点,
(
x
,
y
)
(x,y)
(x,y) 到其 lca 的距离都是 d,lca 到 u的距离为
d
−
k
d - k
d−k
(相当于两个点到其lca 距离相等,第三个点到u点的距离应为 k 的点对数)
令 d p [ u ] [ k ] dp[u][k] dp[u][k] 表示 u为根结点,与 u 距离为 k 的结点数量。
那么答案就是 ∑ v ∈ s o n u ∑ i = 0 l e n [ u ] ( d p [ u ] [ i ] ∗ t p [ v ] [ i + 1 ] + d p [ v ] [ i − 1 ] ∗ t p [ u ] [ i ] ∗ [ i > 0 ] ) \displaystyle\sum_{v \in son_u}\sum_{i = 0}^{len[u]}(dp[u][i] * tp[v][i + 1] +dp[v][i - 1] *tp[u][i] * [i > 0]) v∈sonu∑i=0∑len[u](dp[u][i]∗tp[v][i+1]+dp[v][i−1]∗tp[u][i]∗[i>0])
转移方程:
1.
d
p
[
u
]
[
i
]
=
∑
v
∈
s
o
n
u
d
p
[
v
]
[
i
−
1
]
dp[u][i] = \displaystyle\sum_{v \in son_u}dp[v][i - 1]
dp[u][i]=v∈sonu∑dp[v][i−1]
2.
t
p
[
u
]
[
i
]
=
∑
v
∈
s
o
n
u
t
p
[
v
]
[
i
+
1
]
+
d
p
[
v
]
[
i
−
1
]
∗
d
p
[
u
]
[
i
]
tp[u][i] = \displaystyle\sum_{v \in son_u} tp[v][i + 1] +dp[v][i - 1] * dp[u][i]
tp[u][i]=v∈sonu∑tp[v][i+1]+dp[v][i−1]∗dp[u][i]
dp[u][i] * dp[v][i - 1]代表lca为 u的情况)
如果直接树形dp,复杂度为 n 2 n^2 n2,这时长链剖分就派上用场,可以将这种dp优化到 O ( n ) O(n) O(n)
长链剖分的做法是
O
(
1
)
O(1)
O(1) 继承重儿子的 dp 信息,这需要用到指针,动态分配 dp 空间。
对于轻儿子,暴力计算,然后和树形dp一样的合并方式,复杂度就 神奇地 降下来了。
对dp注意一下边界,避免重复计数。
(计数类树形dp只要处理好能不遗漏不重复的统计所有情况的答案,就能保证正确,对每一种情况都要分析,不能遗漏)
由于
t
p
tp
tp 转移是逆过来的,因此tp的继承也要逆过来,做法是每次开两倍的空间,使得 tp 有前后两段可以用,具体见代码:
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 10;
vector<int> g[maxn];
typedef long long ll;
ll tmp[maxn << 2],*id,*dp[maxn],*tp[maxn];
int len[maxn],son[maxn],n;
ll ans = 0;
void prework(int u,int fa) {
len[u] = 0;son[u] = 0;
for(int i = 0; i < g[u].size(); i++) {
int it = g[u][i];
if(it == fa) continue;
prework(it,u);
if(son[u] == 0 || len[son[u]] < len[it])
son[u] = it;
}
len[u] = len[son[u]] + 1;
}
void dfs(int u,int fa) {
dp[u][0] = 1;
if(son[u]) {
tp[son[u]] = tp[u] - 1;dp[son[u]] = dp[u] + 1;
dfs(son[u],u);
ans += tp[u][0];
}
for(int i = 0; i < g[u].size(); i++) {
int it = g[u][i];
if(it == fa || it == son[u]) continue;
dp[it] = id; id += len[it] << 1;
tp[it] = id; id += len[it];
dfs(it,u);
for(int i = 0; i < len[it]; i++) {
if(i > 0) ans += 1ll * dp[u][i - 1] * tp[it][i];
ans += 1ll * tp[u][i + 1] * dp[it][i];
}
for(int i = 0; i <= len[it]; i++) {
if(i < len[it] - 1)
tp[u][i] += tp[it][i + 1];
if(i) {
tp[u][i] += dp[it][i - 1] * dp[u][i];
dp[u][i] += dp[it][i - 1];
}
}
}
}
int main() {
scanf("%d",&n);
for(int i = 1; i < n; i++) {
int x,y;scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
prework(1,0);
id = tmp,dp[1] = id,id += len[1] << 1;
tp[1] = id,id += len[1];
dfs(1,0);
printf("%lld\n",ans);
return 0;
}