【题目链接】
【思路要点】
- 写一个运用Matrix-Tree定理在取模一个大质数下解决本题的程序,通过打表找到规律:$$Ans=N^{M-1}M^{N-1}$$
- 由于模数可能很大,具体实现时需要将乘法替换为快速加。时间复杂度\(O(Log^{2}_{N})\)。
- 可以用Matrix-Tree定理用\(N\)和\(M\)来表示拉普拉斯矩阵(除去一行一列)的行列式来证明上述结论。
- 或者,我们也可以用Prufer序列来证明上述结论。由于图是一张二分图,Prufer序列中一定会有\(N-1\)个\(1\)到\(M\)之间的数和\(M-1\)个\(1\)到\(N-1\)的数,分别生成上述两类有序数集,共有\(N^{M-1}M^{N-1}\)种。
- 考虑如何用上述数集生成对应的Prufer序列。不在数集中的最小的数只有一个,若其表示的是二分图左侧的点,那么Prufer序列的第一个元素一定是右侧的点的对应数集中的第一个数;反之,若其表示的是二分图右侧的点,那么Prufer序列的第一个元素一定是左侧的点的对应数集中的第一个数。因此,我们的每一个数集(组)都能唯一对应一个合法的Prufer序列,并且显然,这样的对应是没有遗漏的。因此,我们证明了上述结论。
【代码】
/*Program till line 30*/ #include<bits/stdc++.h> using namespace std; #define MAXN 5005 template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } long long P; long long times(long long x, long long y) { if (y == 0) return 0; long long tmp = times(x, y / 2); if (y % 2 == 0) return (tmp + tmp) % P; else return (tmp + tmp + x) % P; } long long power(long long x, long long y) { if (y == 0) return 1; long long tmp = power(x, y / 2); if (y % 2 == 0) return times(tmp, tmp); else return times(tmp, times(tmp, x)); } int main() { long long n, m; read(m), read(n), read(P); cout << times(power(n, m - 1), power(m, n - 1)) << endl; return 0; } /*Use the program below to print a form of answer*/ #include<bits/stdc++.h> using namespace std; #define MAXN 505 #define P 1000000007 template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } long long a[MAXN][MAXN]; long long calc(int n) { int f = 1; for (int i = 1; i <= n; i++) { for (int j = i; j <= n; j++) if (a[j][i]) { if (i != j) { swap(a[i], a[j]); f = -f; } break; } if (a[i][i] == 0) return 0; for (int j = i + 1; j <= n; j++) { if (j == i || a[j][i] == 0) continue; while (a[j][i]) { swap(a[i], a[j]); f = -f; long long tmp = a[j][i] / a[i][i]; for (int k = 1; k <= n; k++) a[j][k] = (a[j][k] - a[i][k] * tmp % P + P) % P; } } } long long ans = 1; for (int i = 1; i <= n; i++) ans = ans * a[i][i] % P; if (f == 1) return ans; else return (P - ans) % P; } int main() { int n, m; read(n), read(m); for (int i = 1; i <= n; i++) { a[i][i] = m; for (int j = n + 1; j <= n + m; j++) a[i][j] = a[j][i] = P - 1; } for (int i = n + 1; i <= n + m; i++) a[i][i] = n; printf("%lld\n", calc(n + m - 1)); return 0; }