D1. LuoTianyi and the Floating Islands (Easy Version)
题意:
给你一棵n个节点的树,随机选择k的不同的节点,到这k个节点的距离和最小的节点称为“好节点”,让你求这些好节点的期望值。
思路:
这题光是想清楚题目意思就很麻烦了,第一次做让求期望的题。
考虑树肯定是不如考虑链来的方便,所以我们先考虑链状情况吧。
-
k为奇数,好节点只有一个,证明比较复杂,不过很容易看出来,所以期望为1
-
k为偶数,好节点的个数就包括1、2、3…n了,所以这种情况下的期望为:
- 1 ∗ p 1 + 2 ∗ p 2 + 3 ∗ p 3 + . . . . + ( n − 1 ) ∗ p n − 1 + n ∗ p n 1*p_{1}+2*p_{2}+3*p_{3}+....+(n-1)*p_{n-1}+n*p_{n} 1∗p1+2∗p2+3∗p3+....+(n−1)∗pn−1+n∗pn
但是这样还是不好求,因为 p i p_{i} pi是不好求的,这种情况下有一个常用的解法,就是考虑每个点的贡献,即它会被计算几次?每当它的子树(包括自己)中被选择了 k 2 \frac{k}{2} 2k,它的非子树中也选择了 k 2 \frac{k}{2} 2k节点时就会被计算一次。
那么这种计算有多少种呢?显然有 C s z [ x ] k 2 ∗ C n − s z [ x ] k 2 C^{\frac{k}{2}}_{sz[x]}*C^{\frac{k}{2}}_{n - sz[x]} Csz[x]2k∗Cn−sz[x]2k种。
这题比较麻烦的还有求组合数和最后的取模逆操作,居然还用到了费马小定理,属实惊艳到我了。
代码
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<map>
#include<unordered_map>
using namespace std;
#define _int64 long long
const int faclim = 310000, mod = 1e9 + 7;
_int64 fac[faclim], invfac[faclim];
_int64 power(int x, int y){
_int64 res = 1;
for(int i = 30; i >= 0; i --){
res = res * res % mod;
if((y >> i) & 1) res = res * x % mod;
}
return res;
}
_int64 inv(int x){
return power(x, mod - 2); // 费马小定理
}
_int64 C(int x, int y){
return fac[x] * invfac[y] % mod * invfac[x - y] % mod;
}
void init(){
fac[0] = 1;
for(int i = 1; i < faclim; i ++)
fac[i] = fac[i - 1] * i % mod;
invfac[faclim - 1] = inv(fac[faclim - 1]);
for(int i = faclim - 1; i > 0; i --)
invfac[i - 1] = invfac[i] * i % mod;
}
// 以上都是板子,直接用即可。
const int N = 2e5 + 10;
vector<int> a[N];
int p[N], sz[N];
int main(){
int n, k, x, y, i, j;
scanf("%d %d", &n, &k);
for(int i = 0; i < n; i ++){
p[i] = -1;
a[i].clear();
}
init();
for(i = 0; i < n - 1; i ++){
scanf("%d %d", &x, &y);
x --; y --;
a[x].push_back(y);
a[y].push_back(x);
}
if(k % 2 == 1){
puts("1");
return 0;
}
vector<int> q;
q.push_back(0);
for(i = 0; i < q.size(); i ++){
int x = q[i];
for(int j = 0; j < a[x].size(); j ++){
int y = a[x][j];
if(y == p[x]) continue;
p[y] = x;
q.push_back(y);
}
}
_int64 ans = 0;
int half = k / 2;
for(i = q.size() - 1; i >= 0; i --){
int x = q[i];
sz[x] = 1;
for(int j = 0; j < a[x].size(); j ++){
int y = a[x][j];
if(y == p[x]) continue;
sz[x] += sz[y];
}
if(sz[x] >= half && (n - sz[x]) >= half){
ans += (_int64)C(sz[x], half) * C(n - sz[x], half) % mod;
ans %= mod;
}
}
_int64 tmp = C(n, k);
ans += tmp;
ans %= mod;
// 要求 x 使得 x * tmp % mod = ans
// inv(tmp) = tmp^(mod - 2) => inv(tmp) * tmp % mod = 1
ans *= inv(tmp);
ans %= mod;
// ans * tmp % mod = ans 满足了要求
printf("%lld\n", ans);
return 0;
}