CodeForces 1109D Sasha and Interesting Fact from Graph Theory
题目大意
给定 N N N个结点,编号为 1 , 2 , … , N 1,2,\ldots,N 1,2,…,N,再给定两个点 A , B A,B A,B,要求这两个点的距离为 M M M,且树中的每一条边权均为 1 1 1到 M M M的整数,求这样能构造的树的方案数。
分析
显然可以知道 A , B A,B A,B是什么值对整个答案的计算没有影响。
考虑按如下方法构造:
即先构造出一条有 x x x个结点链,再在这条链上随意连接子树。
这个链的方案数为 A N − 2 x − 1 A_{N-2}^{x-1} AN−2x−1。
那么在这条链上,分配的方案数为 C M − 1 x − 1 C_{M-1}^{x-1} CM−1x−1。
这一部分的方案数为 A N − 2 x − 1 ⋅ C M − 1 x − 1 A_{N-2}^{x-1}\cdot C_{M-1}^{x-1} AN−2x−1⋅CM−1x−1
考虑计算剩余的结点的方案数:
利用Prufer序列的性质,我们将构造出来的链顺序标上 N − x , N − x + 1 , ⋯ , N N-x,N-x+1,\cdots,N N−x,N−x+1,⋯,N,这样一来,将我们构造出的树的Prufer序列的最后 x − 2 x-2 x−2位一定就是我们所构造出来的链。
那么考虑前 N − x N-x N−x个结点,它们在序列中的顺序是随意的,不难得出这 N − x N-x N−x个结点随意排列的方案数为 N N − x − 2 N^{N-x-2} NN−x−2。此外,又由于这个序列中除了最后一个点是固定的,这个序列的剩余的点必定是任意值,所以这个方案数为 ( x + 1 ) ⋅ N N − x − 2 (x+1)\cdot N^{N-x-2} (x+1)⋅NN−x−2。
好像网上有些说可以用广义Cayley定理来解决这个问题。。。
最后考虑边权,因为我们事先已经固定了 x x x条边的边权,其他的边随意选取即可,所以这个方案数就是 M N − x − 1 M^{N-x-1} MN−x−1。
所以最后的总方案数为 ∑ i = 1 N − 1 A N − 2 x − 1 ⋅ C M − 1 x − 1 ⋅ ( x + 1 ) ⋅ N N − x − 2 ⋅ M N − x − 1 \sum_{i=1}^{N-1}{}A_{N-2}^{x-1}\cdot C_{M-1}^{x-1}\cdot(x+1)\cdot N^{N-x-2}\cdot M^{N-x-1} i=1∑N−1AN−2x−1⋅CM−1x−1⋅(x+1)⋅NN−x−2⋅MN−x−1
此外,注意到 i = N − 1 i=N-1 i=N−1时有一项的次数是负数,特判一下就是了。
参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const int Maxn = 1e6;
const int Mod = 1e9 + 7;
ll fac[Maxn + 5], inv_fac[Maxn + 5];
ll QuickPow(ll a, ll k) {
ll ret = 1;
while(k) {
if(k & 1) ret = ret * a % Mod;
a = a * a % Mod;
k >>= 1;
}
return ret;
}
void Init() {
fac[0] = inv_fac[0] = 1;
for(int i = 1; i <= Maxn; i++)
fac[i] = fac[i - 1] * i % Mod;
inv_fac[Maxn] = QuickPow(fac[Maxn], Mod - 2);
for(int i = Maxn - 1; i >= 1; i--)
inv_fac[i] = inv_fac[i + 1] * (i + 1) % Mod;
}
ll C(int n, int m) {
if(n < 0 || m < 0 || n < m) return 0;
return fac[n] * inv_fac[m] % Mod * inv_fac[n - m] % Mod;
}
ll A(int n, int m) {
if(n < 0 || m < 0 || n < m) return 0;
return fac[n] * inv_fac[n - m] % Mod;
}
int N, M;
int main() {
#ifdef LOACL
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
Init();
int t1, t2;
scanf("%d %d %d %d", &N, &M, &t1, &t2);
ll ans = 0;
for(int i = 1; i < N; i++) {
ll tmp = A(N - 2, i - 1) * C(M - 1, i - 1) % Mod;
tmp = tmp * QuickPow(M, N - i - 1) % Mod;
if(i != N - 1) tmp = tmp * QuickPow(N, N - i - 2) % Mod * (i + 1) % Mod;
ans = (ans + tmp) % Mod;
}
printf("%lld\n", ans);
return 0;
}