题目大意:
就是给你一个 2 n 2n 2n个点的完全图,从这个图里面抽出 2 n − 1 2n-1 2n−1条边,这些边形成一颗树,现在问你剩下的图里面点进行完美匹配有多少种方案?
解题思路:
- 一开始被完美匹配給限定死了思维。其实我们知道实际上就是对原图进行选边,有的边不能选而已,那么其实我们直接容斥就行了怎么容斥呢?
- 就是对于上面那些树边,我们目的是要求不选,并且两两匹配,如果我们不考虑这个直接全部选的话,那么我们就可能会选到1条树边,2条树边,3条树边依次类推,那么我们就可以直接减就好了
- 我们定义一定选其中 x x x条树边,那么可能会包含 x + 1 x+1 x+1, x + 2 x+2 x+2条…
- 对于一定选多少条树边,我们可以直接跑树形dp,就是 d p [ i ] [ j ] [ 0 / 1 ] dp[i][j][0/1] dp[i][j][0/1]:表示以第 i i i个点为根,下面有 j j j个匹配,并且 i i i节点是否匹配的方案数。那么转移就很明显了
- 假如现在我们已经一定选择
k
k
k条边了,那么就匹配了
2
k
2k
2k个点,剩下
2
n
−
2
k
2n-2k
2n−2k个点进行匹配,方案数就是
s u m = C ( 2 n − 2 k , n − k ) ∗ A n − k 2 n − k sum=\frac{C(2n-2k,n-k)*A^{n-k}}{2^{n-k}} sum=2n−kC(2n−2k,n−k)∗An−k
就是先把集合分成两部分,对其中一部分进行全排列,然后每一对反过来组合是一样的就是要除以 2 n − k 2^{n-k} 2n−k
AC code
#include <bits/stdc++.h>
#define mid ((l + r) >> 1)
#define Lson rt << 1, l , mid
#define Rson rt << 1|1, mid + 1, r
#define ms(a,al) memset(a,al,sizeof(a))
#define log2(a) log(a)/log(2)
#define lowbit(x) ((-x) & x)
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define INF 0x3f3f3f3f
#define LLF 0x3f3f3f3f3f3f3f3f
#define f first
#define s second
#define endl '\n'
using namespace std;
const int N = 2e6 + 10, mod = 998244353;
const int maxn = 500010;
const long double eps = 1e-5;
const int EPS = 500 * 500;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef pair<double,double> PDD;
template<typename T> void read(T &x) {
x = 0;char ch = getchar();ll f = 1;
while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();}
while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f;
}
template<typename T, typename... Args> void read(T &first, Args& ... args) {
read(first);
read(args...);
}
int k;
vector<int>g[maxn];
int siz[maxn];
int tmp[4002][2];
int dp[4002][4002][2]; // dp[i][j][0/1] 表示i这个子树里面有选出了j个匹配且i是否匹配的方案数
int fac[maxn], inv[maxn];
ll qim(ll a, ll b) {
ll res = 1;
while(b) {
if(b & 1) res = a * res % mod;
a = a * a % mod;
b >>= 1;
}
return res % mod;
}
inline void init() {
fac[0] = 1;
for(int i = 1; i < maxn; ++ i) fac[i] = 1ll * fac[i-1] * i % mod;
inv[0] = 1;
inv[1] = qim(2,mod-2);
for(int i = 2; i < maxn; ++ i) inv[i] = 1ll * inv[i-1] * inv[1] % mod;
}
inline ll C(int a, int b) {
if (a < b) return 0;
return 1ll * fac[a] * qim(fac[b],mod-2) % mod * qim(fac[a-b],mod-2)%mod;
}
inline void dfs(int u, int fa) {
siz[u] = 1;
dp[u][0][0] = 1;
for(auto it : g[u]) {
if(it == fa) continue;
dfs(it,u);
ms(tmp,0);
for(int i = 0; i <= siz[u]/2; ++ i)
for(int j = 0; j <= siz[it]/2; ++ j) {
tmp[i+j][0] = (tmp[i+j][0] + 1ll*dp[u][i][0]*(dp[it][j][0]+dp[it][j][1])%mod)%mod;
tmp[i+j][1] = (tmp[i+j][1] + 1ll*dp[u][i][1]*(dp[it][j][0]+dp[it][j][1])%mod)%mod;
tmp[i+j+1][1] = (tmp[i+j+1][1] + 1ll*dp[u][i][0]*dp[it][j][0]%mod)%mod;
}
for(int i = 0; i <= siz[u]/2+siz[it]/2+1; ++ i) {
dp[u][i][0] = tmp[i][0];
dp[u][i][1] = tmp[i][1];
}
siz[u] += siz[it];
}
}
int main() {
IOS;
init();
cin >> k;
k *= 2;
for(int i = 1; i < k; ++ i) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0);
ll ans = 0;
for(int i = 0; i <= k/2; ++ i) {
ll res = (dp[1][i][0] + dp[1][i][1]) % mod;
int lim = k - 2 * i;
ll cnt = 1ll*C(lim,lim/2) * fac[lim/2] % mod * inv[lim/2] % mod;
if(i & 1)
ans = (ans - cnt * res % mod + mod) % mod;
else ans = (ans + cnt * res % mod) % mod;
}
cout << ans;
return 0;
}