前置芝士
Kirchhoff 矩阵树定理
Kirchhoff矩阵树定理解决了一个问题:对于一个确定的无向图,其究竟有多少个生成树?
对于一个无向图,我们拥有其邻接矩阵
A
\bf{A}
A。
这里的邻接矩阵允许重边,第
i
i
i 行第
j
j
j 列的值代表着点
i
i
i 到点
j
j
j 有几条边。
不允许自环。
我们定义一个无向图的度数矩阵 D \bf{D} D 为,第 i i i 行第 i i i 列上的数字是点 i i i 的度数,其余的格子都为 0 0 0 的矩阵。
我们定义一个图的 Kirchhoff矩阵 K = D − A {\bf{K}} = {\bf{D}} - {\bf{A}} K=D−A。
这个矩阵同时去掉任意一行一列,剩下的这个子矩阵的行列式的绝对值,就是该无向图的生成树个数。
行列式
行列式可以被理解为列向量夹的几何体的体积,用 det ( A ) \det({\bf{A}}) det(A) 表示 A \bf{A} A 这个方阵的行列式。
这里不讲太细,细一点的可以去看我的另外一篇博客。
在这里,我们只需要知道几个简单的点:
- 每一个方阵都有自己的行列式,公式是 det ( A ) = ∑ σ ∈ S n sgn ( σ ) ∏ i = 1 n a i , σ ( i ) \displaystyle \det({\bf{A}}) = \sum_{\sigma \in S_n} \operatorname{sgn}(\sigma) \prod_{i=1}^n a_{i,\sigma(i)} det(A)=σ∈Sn∑sgn(σ)i=1∏nai,σ(i)。
- 三角矩阵的行列式是可以以很小的复杂度计算出来的。具体方法是其对角线之积。
- 方阵的行列式有如下几个性质:
- 矩阵转置,行列式不变;
- 矩阵行(列)交换,行列式取反;
- 矩阵行(列)相加减,行列式不变;
- 矩阵行(列)所有元素同时乘以数 k k k,行列式也乘 k k k。
- 通过上面的几个操作(其实就是高斯消元)可以将一个矩阵消为一个三角矩阵,而反过来则可以从三角矩阵变回原矩阵。
那么我们存储一个系数,在消元的时候维护这个系数,直到最后得到三角矩阵的时候再将系数乘在这个我们能够简单求出的行列式上,就可以得到原矩阵的行列式了。
题目解读
题目要求我们求出一个无向图的生成树数量,并且该生成树需要满足其中的每一条边都属于一个不同的集合。
那么我们考虑容斥,将“恰好”改为“至多”。
这样我们枚举由
n
−
1
n-1
n−1 个集合构成的无向图的时候会记录上由
n
−
2
n-2
n−2 个集合构成的无向图的结果,需要减去这部分的影响。
然后枚举由
n
−
2
n-2
n−2 个集合构成的无向图,同时会记录上由
n
−
3
n-3
n−3 个集合构成的无向图的结果。
这样一直容斥下去,直到最后枚举由
1
1
1 个集合构成的无向图的结果的时候,容斥就可以停止了。
每一次枚举的时候,我们都需要跑矩阵树定理。每一次跑矩阵树定理都是 O ( ( n − 1 ) 3 ) O((n-1)^3) O((n−1)3) 的,我们需要跑 2 n − 1 2^{n-1} 2n−1 次,所以最终的复杂度是 O ( 2 n − 1 ( n − 1 ) 3 ) O(2^{n-1}(n-1)^3) O(2n−1(n−1)3) 的。
参考代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 20, M = N * N;
const ll mod = 1e9 + 7;
int n, maxn;
int m[N];
int u[N][M], v[N][M];
int sz[(1 << 18)];
ll krh[N][N];
ll qpow(ll a, ll x)
{
ll res = 1;
while(x)
{
if(x & 1)res = (res * a) % mod;
a = (a * a) % mod;
x >>= 1;
}
return res;
}
inline ll det()
{
ll res = 1;
int flag = 0;
for(int i = 1; i <= n - 1; i++)
{
if(krh[i][i] == 0)
{
for(int j = i + 1; j <= n - 1; j++)
{
if(krh[j][i] == 0)continue;
flag ^= 1;
for(int k = 1; k <= n - 1; k++)
swap(krh[i][k], krh[j][k]);
}
}
for(int j = i; j <= n - 1; j++)
{
if(krh[j][i] == 0)continue;
ll inv = qpow(krh[j][i], mod - 2);
res = (res * krh[j][i]) % mod;
for(int k = i; k <= n - 1; k++)
krh[j][k] = (krh[j][k] * inv) % mod;
}
for(int j = i + 1; j <= n - 1; j++)
{
if(krh[j][i] == 0)continue;
for(int k = i; k <= n - 1; k++)
krh[j][k] = (krh[j][k] - krh[i][k] + mod) % mod;
}
}
for(int i = 1; i <= n - 1; i++)
res = (res * krh[i][i]);
return (flag) ? (mod - res) % mod : res;
}
int main()
{
scanf("%d", &n);
maxn = (1 << (n - 1)) - 1;
for(int i = 1; i < n; i++)
{
scanf("%d", &m[i]);
for(int j = 1; j <= m[i]; j++)
scanf("%d%d", &u[i][j], &v[i][j]);
}
for(int i = 1; i <= maxn; i++)
sz[i] = sz[i >> 1] + (i & 1);
ll res = 0;
for(int i = 1; i <= maxn; i++)
{
memset(krh, 0, sizeof(krh));
for(int j = 1, p = i; p; p >>= 1, j++)
{
if((p & 1) == 0)continue;
for(int k = 1; k <= m[j]; k++)
{
int U = u[j][k], V = v[j][k];
krh[U][U]++, krh[V][V]++;
krh[U][V] = (krh[U][V] + mod - 1) % mod;
krh[V][U] = (krh[V][U] + mod - 1) % mod;
}
}
res = (res + mod + det() * ((n - sz[i]) % 2 ? 1 : -1)) % mod;
}
printf("%lld\n", res);
return 0;
}