第三题:T3树的匹配
标签:树形
D
P
DP
DP、乘法逆元
题意:给定一棵有
n
n
n个节点的树,根节点为
1
1
1。求这棵树的最大匹配数,并统计最大匹配数情况下的方案数。最终结果,对
1
0
9
+
7
10^9+7
109+7取余。树的匹配,指的是具有父子关系的点,两两组成一对,每个点只能在一个配对里。
题解:很明显这是一道树形动态规划。
d
p
[
u
]
[
0
/
1
]
dp[u][0/1]
dp[u][0/1]:节点编号为
u
u
u的子树中,
u
u
u不配对/配对时其子树的最大匹配数。
d
p
[
u
]
[
0
]
=
∑
v
M
a
x
(
d
p
[
v
]
[
0
]
,
d
p
[
v
]
[
1
]
)
dp[u][0] = \sum\nolimits_{v}Max(dp[v][0],dp[v][1])
dp[u][0]=∑vMax(dp[v][0],dp[v][1])
当
u
u
u不配对的时候,把所有孩子节点
v
v
v最大匹配数(
v
v
v不配对或者配对的情况下中的最大值)加起来。
统计完
d
p
[
u
]
[
0
]
dp[u][0]
dp[u][0]之后,
d
p
[
u
]
[
1
]
dp[u][1]
dp[u][1]表示的是配对的情况,需要从所有孩子节点
v
v
v没有配对的情况转移过来;当前的孩子节点
v
′
v'
v′配对的情况下,其他的孩子节点还是拿
m
a
x
(
d
p
[
v
]
[
0
]
,
d
p
[
v
]
[
1
]
)
max(dp[v][0],dp[v][1])
max(dp[v][0],dp[v][1])。
d
p
[
u
]
[
1
]
=
M
a
x
j
(
d
p
[
u
]
[
0
]
−
m
a
x
(
d
p
[
v
]
[
0
]
,
d
p
[
v
]
[
1
]
)
+
d
p
[
v
]
[
0
]
+
1
)
dp[u][1] = Max_j(dp[u][0]-max(dp[v][0], dp[v][1]) + dp[v][0] + 1)
dp[u][1]=Maxj(dp[u][0]−max(dp[v][0],dp[v][1])+dp[v][0]+1)
这边的
+
1
+1
+1指的是
u
u
u和当前枚举的孩子节点组成的一个新的匹配数。
可以根据下面的图,同学自己再推一推。
c
n
t
[
u
]
[
0
/
1
]
cnt[u][0/1]
cnt[u][0/1]:节点编号为
u
u
u的子树中,
u
u
u不配对/配对时其子树的最大匹配数下的方案数
c
n
t
[
u
]
[
0
]
cnt[u][0]
cnt[u][0]其实比较好想,
u
u
u节点既然不配对了,那么只需要管它的所有孩子节点
v
v
v,我们是求最大匹配数的情况下方案数,所以我们得比较一下
d
p
[
v
]
[
0
]
dp[v][0]
dp[v][0]和
d
p
[
v
]
[
1
]
dp[v][1]
dp[v][1]中哪个传上来的最大匹配数大,或者一样大。
所以这边我写了一个
s
e
l
e
c
t
(
v
)
select(v)
select(v)函数:
s
e
l
e
c
t
(
v
)
=
d
p
[
v
]
[
0
]
>
=
d
p
[
v
]
[
1
]
∗
c
n
t
[
v
]
[
0
]
+
d
p
[
v
]
[
0
]
<
=
d
p
[
v
]
[
1
]
∗
c
n
t
[
v
]
[
1
]
select(v)=dp[v][0]>=dp[v][1]*cnt[v][0]+dp[v][0]<=dp[v][1]*cnt[v][1]
select(v)=dp[v][0]>=dp[v][1]∗cnt[v][0]+dp[v][0]<=dp[v][1]∗cnt[v][1]
ll select(ll v) {
if (dp[v][0] > dp[v][1]) return cnt[v][0];
else if (dp[v][0] < dp[v][1]) return cnt[v][1];
else return (cnt[v][0] + cnt[v][1]) % mod;
}
接下来考虑一下,对于
u
u
u节点来说,所有子树的方案数传上来,是不是得乘积一下(乘法原理)。
公式化:
c
n
t
[
u
]
[
0
]
=
∏
v
(
s
e
l
e
c
t
(
v
)
)
cnt[u][0]=\prod\nolimits_v(select(v))
cnt[u][0]=∏v(select(v))
然后这边我们从上面的
d
p
[
u
]
[
1
]
dp[u][1]
dp[u][1]的情况,推一推,上面的是加法原理,然后
u
u
u和某一个孩子节点
v
v
v进行匹配的情况下,我们是先减去
m
a
x
(
d
p
[
v
]
[
0
]
,
d
p
[
v
]
[
1
]
)
max(dp[v][0],dp[v][1])
max(dp[v][0],dp[v][1]),那同理 这边乘法原理,我们要去求
c
n
t
[
u
]
[
1
]
cnt[u][1]
cnt[u][1]是不是得先除一下。
那接下来我们看实际在什么情况下进行转移,如果在 d p [ u ] [ 1 ] dp[u][1] dp[u][1]能够变的更大的时候,当然直接从 c n t [ v ] [ 0 ] cnt[v][0] cnt[v][0]( v v v不选的情况)的地方转移过来。即 c n t [ u ] [ 1 ] = c n t [ u ] [ 0 ] / s e l e c t ( v ) ∗ c n t [ v ] [ 0 ] cnt[u][1] = cnt[u][0]/select(v)*cnt[v][0] cnt[u][1]=cnt[u][0]/select(v)∗cnt[v][0]
那如果当前和之前 d p [ u ] [ 1 ] dp[u][1] dp[u][1]的值一样,那也就是说有全新的也能够组成最大匹配数的情况,是不是得求和加上来。即 c n t [ u ] [ 1 ] = c n t [ u ] [ 1 ] + c n t [ u ] [ 0 ] / s e l e c t ( v ) ∗ c n t [ v ] [ 0 ] cnt[u][1] = cnt[u][1]+cnt[u][0]/select(v)*cnt[v][0] cnt[u][1]=cnt[u][1]+cnt[u][0]/select(v)∗cnt[v][0]
需要额外考虑的点是,整道题的方案数最终结果需要取模,除法运算是不符合同余定理的,我们得使用逆元给它转一下。这边通过快速幂去实现一下费马小定理求逆元的。
费马小定理:如果
p
p
p是一个质数,而整数
a
a
a不是
p
p
p的倍数,则有
a
p
−
1
≡
1
(
m
o
d
p
)
a^{p-1}≡1\ (mod \ \ p)
ap−1≡1 (mod p)。
推导得
a
p
−
2
=
i
n
v
(
a
)
(
m
o
d
p
)
a^{p-2}=inv(a) \ (mod \ \ p)
ap−2=inv(a) (mod p)=>
i
n
v
(
a
)
=
a
p
−
2
(
m
o
d
p
)
inv(a)= a^{p-2}\ (mod \ \ p)
inv(a)=ap−2 (mod p)
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;
ll n, x;
ll dp[200005][2], cnt[200005][2];
// dp[i][0/1]: i这个顶点不放/放到最终匹配方案中
vector <ll> e[200005];
ll fast_power(ll a, ll b) {
ll ans = 1;
a %= mod;
while (b) {
if (b & 1) ans = (ans * a) % mod;
a = (a * a) % mod;
b >>= 1;
}
return ans;
}
ll select(ll v) {
if (dp[v][0] > dp[v][1]) return cnt[v][0];
else if (dp[v][0] < dp[v][1]) return cnt[v][1];
else return (cnt[v][0] + cnt[v][1]) % mod;
}
void dfs(ll u) {
cnt[u][0] = 1;
for (int i = 0; i < e[u].size(); i++) {
ll v = e[u][i];
dfs(v);
dp[u][0] += max(dp[v][0], dp[v][1]);
cnt[u][0] = (cnt[u][0] * select(v)) % mod;
}
for (int i = 0; i < e[u].size(); i++) {
ll v = e[u][i];
ll otrher = dp[u][0] - max(dp[v][0], dp[v][1]);
if (dp[u][1] < dp[v][0] + otrher + 1) {
dp[u][1] = dp[v][0] + otrher + 1;
ll k = cnt[u][0] * fast_power(select(v), mod - 2) % mod;
cnt[u][1] = (k * cnt[v][0] + mod) % mod;
}
else if (dp[u][1] == dp[v][0] + otrher + 1) {
ll k = cnt[u][0] * fast_power(select(v), mod - 2) % mod;
cnt[u][1] = (cnt[u][1] + (k * cnt[v][0]) % mod + mod) % mod;
}
}
}
int main() {
cin >> n;
for (int i = 2; i <= n; i++) {
cin >> x;
e[x].push_back(i);
}
dfs(1);
cout << max(dp[1][0], dp[1][1]) << endl;
cout << select(1) << endl;
return 0;
}