题目大意:
给你一颗节点数为n的树, 选定一个子集
如果一个未被选择的点度数为1, 该点被删除
如果一个未被选择的点度数为2, 该点被删除, 并将其相连的两个点连接。
求对每一个k, 选定的k子集后期望剩下的点数, 模
109+7
10
9
+
7
。
(n≤5000)
(
n
≤
5000
)
题目思路:
考虑选定了一个子集大小k后, 考虑每一个点的贡献
对于一个度数小于等于2的点, 他要留下来(对答案的分子产生贡献1)的情况只有他自己被选中, 否则一定会被删, 故有C(n - 1, k - 1)种。
对于一个度数大于2的点, 他要被删掉的情况, 当且仅当所有k个点都选在他的某两个孩子子树内。 这种情况下就是, 其他子树的点会从叶子开始一直删下来删光, 然后他只剩下了度数2, 也被删掉了。
设它有m个孩子, 每个孩子有sz[i], 则被删掉的方案数是
后面那一项是减去算重的部分, 从第一项的式子中可以看出,k个点都选在同一个子树内的情况, 即每个C(sz[i], k)会被算(m-1)次, 故要减去(m-2)个。
然后度数大于2点的贡献就是C(n, k)减去被删掉的方案了。
最后考虑对于每个k来求答案
对于第一种情况, 我们可以预先数出有多少个度数小于等于2的点, 然后乘以C(n - 1, k - 1)即可。
对于第二种情况, 同样可以预处理出每个i<=n, C(i, k)前面的系数, 然后O(n)的算一遍即可。
故总的复杂度是O(n^2)的。
PS:
在dfs部分, 每个节点虽然是O(孩子个数^ 2)的枚举计算, 但是总的复杂度依然是O(n^2), 这种复杂度在树形dp中也很常见。
Code:
#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
#define db double
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)
#define ls (x << 1)
#define rs ((x << 1) | 1)
#define mid ((l + r) >> 1)
using namespace std;
const int N = (int)5050;
const int mo = (int)1e9 + 7;
int n;
int cnt, lst[N], nxt[N * 2], to[N * 2];
int deg[N], sz[N], tot; ll tim[N], C[N][N];
ll pw(ll x, ll k){
ll ret = 1;
while (k){
if (k & 1) ret = ret * x % mo;
x = x * x % mo;
k >>= 1;
}
return ret;
}
void add(int u, int v){
nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v;
nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u;
}
void dfs(int u, int fa){
sz[u] = 1;
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (v == fa) continue;
dfs(v, u);
sz[u] += sz[v];
}
if (deg[u] <= 2) tot ++;
else{
for (int j = lst[u]; j; j = nxt[j]){
int v1 = to[j];
if (v1 == fa) continue;
for (int k = lst[u]; k && to[k] != v1; k = nxt[k]){
int v2 = to[k];
if (v2 == fa) continue;
(tim[sz[v1] + sz[v2]] -= 1) %= mo;
}
(tim[sz[v1]] += (deg[u] - 2)) %= mo;
}
if (u != 1){
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (v == fa) continue;
(tim[sz[v] + n - sz[u]] -= 1) %= mo;
}
(tim[n - sz[u]] += (deg[u] - 2)) %= mo;
}
}
}
int main(){
scanf("%d", &n);
for (int i = 1, u, v; i < n; i ++){
scanf("%d %d", &u, &v);
add(u, v); deg[u] ++, deg[v] ++;
}
dfs(1, 0);
C[0][0] = 1;
for (int i = 1; i < N; i ++){
C[i][0] = 1;
for (int j = 1; j <= i; j ++)
(C[i][j] = C[i - 1][j] + C[i - 1][j - 1]) %= mo;
}
for (int k = 1; k <= n; k ++){
ll ans = 0;
(ans += tot * C[n - 1][k - 1] % mo) %= mo;
(ans += (n - tot) * C[n][k] % mo) %= mo;
for (int i = 1; i <= n; i ++)
(ans += tim[i] * C[i][k]) %= mo;
if (ans < 0) ans += mo;
ans = ans * pw(C[n][k], mo - 2) % mo;
printf("%lld\n", ans);
}
return 0;
}