题目大意
给定一棵树,有一些询问。每次询问给出
k
k
k个点和两个数
m
,
r
m,r
m,r,表示让原树以
r
r
r为根,把这
k
k
k个点分成至多
m
m
m组,每组内不存在一个点是另一个点的祖先。求方案数膜1000000007.
n
,
Q
≤
1
0
5
,
∑
k
≤
1
0
5
,
m
≤
m
i
n
(
k
,
300
)
n,Q\le 10^5,\sum k\le 10^5,m\le min(k,300)
n,Q≤105,∑k≤105,m≤min(k,300)。
题解
显然先建虚树,并且按照给定根重新遍历虚树。刚开始SB的我想了好久怎么重新确定虚树中谁是谁的祖先……后来才发现直接把
r
r
r加进去一起建虚树就行了qaq。
然后,看数据范围似乎是个
O
(
k
m
)
O(km)
O(km)的做法?想了一会儿树形dp,感觉不太可行。那就估计是组合数学了。
先不考虑组与组之间无区别的问题(即两组分别为{1},{2}和{2},{1}实际上是相同的情况),我们给每个组设定一个编号。遍历虚树,如果某个点向上有
x
x
x个祖先,那么它可以选的编号有
m
−
x
m-x
m−x种,乘起来即可。
显然这样会重复,我们考虑去重。不妨令
f
(
m
)
f(m)
f(m)表示刚刚算出的答案,
g
(
m
)
g(m)
g(m)表示恰好分成
m
m
m个非空无区别组的方案数。那么:
f
(
m
)
=
∑
i
=
1
m
(
m
i
)
g
(
i
)
⋅
i
!
f(m)=\sum_{i=1}^m\binom mi g(i)\cdot i!
f(m)=i=1∑m(im)g(i)⋅i!
二项式反演即可得到:
g
(
m
)
=
1
i
!
∑
i
=
1
m
(
−
1
)
m
−
i
(
m
i
)
f
(
i
)
g(m)=\frac{1}{i!}\sum_{i=1}^m(-1)^{m-i}\binom mi f(i)
g(m)=i!1i=1∑m(−1)m−i(im)f(i)
于是我们可以在
O
(
k
m
)
O(km)
O(km)的时间内算出所有的
f
f
f,利用
f
f
f在
O
(
m
2
)
≤
O
(
k
m
)
O(m^2)\le O(km)
O(m2)≤O(km)的时间内算出所有的
g
g
g,直接求和就是答案。
#include <bits/stdc++.h>
namespace IOStream {
const int MAXR = 1 << 23;
char _READ_[MAXR], _PRINT_[MAXR];
int _READ_POS_, _PRINT_POS_, _READ_LEN_;
inline char readc() {
#ifndef ONLINE_JUDGE
return getchar();
#endif
if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
char c = _READ_[_READ_POS_++];
if (_READ_POS_ == MAXR) _READ_POS_ = 0;
if (_READ_POS_ > _READ_LEN_) return 0;
return c;
}
template<typename T> inline void read(T &x) {
x = 0; register int flag = 1, c;
while (((c = readc()) < '0' || c > '9') && c != '-');
if (c == '-') flag = -1; else x = c - '0';
while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
x *= flag;
}
template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
read(a), read(x...);
}
inline int reads(char *s) {
register int len = 0, c;
while (isspace(c = readc()) || !c);
s[len++] = c;
while (!isspace(c = readc()) && c) s[len++] = c;
s[len] = 0;
return len;
}
inline void ioflush() {
fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
fflush(stdout);
}
inline void printc(char c) {
_PRINT_[_PRINT_POS_++] = c;
if (_PRINT_POS_ == MAXR) ioflush();
}
inline void prints(char *s) {
for (int i = 0; s[i]; i++) printc(s[i]);
}
template<typename T> inline void print(T x, char c = '\n') {
if (x < 0) printc('-'), x = -x;
if (x) {
static char sta[20];
register int tp = 0;
for (; x; x /= 10) sta[tp++] = x % 10 + '0';
while (tp > 0) printc(sta[--tp]);
} else printc('0');
printc(c);
}
template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
print(x, ' '), print(y...);
}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
#define cls(a) memset(a, 0, sizeof(a))
const int MAXN = 100005, MAXM = 200005, MOD = 1000000007;
struct Graph { int to, next; } gra[MAXM];
struct Edge { int to, val, next; } edge[MAXM];
int hd[MAXN], st[20][MAXM], beg[MAXN], dep[MAXN], sta[MAXN], ed[MAXN];
int lg[MAXM], head[MAXN], arr[MAXN], vis[MAXN], sz[MAXN], n, m, tot;
void addgra(int u, int v) {
gra[++tot] = (Graph) { v, hd[u] };
hd[u] = tot;
}
void addedge(int u, int v, int w) {
edge[++tot] = (Edge) { v, w, head[u] };
head[u] = tot;
edge[++tot] = (Edge) { u, w, head[v] };
head[v] = tot;
//printf("%d %d %d\n", u, v, w);
}
void dfs1(int u, int fa) {
dep[st[0][beg[u] = ++tot] = u] = dep[fa] + 1;
sz[u] = 1;
for (int i = hd[u]; i; i = gra[i].next) {
int v = gra[i].to;
if (v != fa) dfs1(v, st[0][++tot] = u), sz[u] += sz[v];
}
ed[u] = tot;
}
int get_min(int a, int b) { return dep[a] < dep[b] ? a : b; }
int get_lca(int a, int b) {
a = beg[a], b = beg[b];
if (a > b) swap(a, b);
int l = lg[b - a + 1];
return get_min(st[l][a], st[l][b - (1 << l) + 1]);
}
bool cmp(const int &a, const int &b) { return beg[a] < beg[b]; }
int q, r, mm;
ll C[305][305], f[305], fac[305], rev[305];
ll modpow(ll a, int b) {
ll res = 1;
for (; b; b >>= 1) {
if (b & 1) res = res * a % MOD;
a = a * a % MOD;
}
return res;
}
void dfs4(int u, int fa) {
for (int &i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa) dfs4(v, u);
}
}
void dfs3(int u, int fa, int d, ll &ff) {
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == fa) continue;
dfs3(v, u, d - vis[u], ff);
}
if (vis[u]) (ff *= d) %= MOD;
}
int main() {
C[0][0] = 1;
for (int i = fac[0] = 1; i <= 300; i++) {
fac[i] = fac[i - 1] * i % MOD;
C[i][0] = 1;
for (int j = 1; j <= i; j++)
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % MOD;
}
rev[300] = modpow(fac[300], MOD - 2);
for (int i = 300; i > 0; i--) rev[i - 1] = rev[i] * i % MOD;
read(n, m);
for (int i = 1; i < n; i++) {
int u, v; read(u, v);
addgra(u, v);
addgra(v, u);
}
dfs1(1, tot = 0);
for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
for (int i = 1; i < 20; i++)
for (int j = 1; j + (1 << i) - 1 <= tot; j++)
st[i][j] = get_min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
while (m--) {
int top = tot = 0, flag = 0; read(q, mm, r);
for (int i = 1; i <= q; i++) {
read(arr[i]), vis[hd[i] = arr[i]] = 1;
if (arr[i] == r) flag = 1;
}
if (!flag) arr[++q] = r;
sort(arr + 1, arr + 1 + q, cmp);
sta[++top] = 1;
for (int i = arr[1] == 1 ? 2 : 1; i <= q; i++) {
int l = get_lca(sta[top], arr[i]);
for (; top > 1 && dep[sta[top - 1]] > dep[l]; top--)
addedge(sta[top - 1], sta[top], dep[sta[top]] - dep[sta[top - 1]]);
if (dep[sta[top]] > dep[l]) addedge(l, sta[top], dep[sta[top]] - dep[l]), --top;
if (dep[sta[top]] < dep[l]) sta[++top] = l;
sta[++top] = arr[i];
}
for (; top > 1; top--) addedge(sta[top - 1], sta[top], dep[sta[top]] - dep[sta[top - 1]]);
ll res = 0;
for (int i = 1; i <= mm; i++) {
f[i] = 1;
dfs3(r, 0, i, f[i]);
ll sum = 0;
for (int j = 1; j <= i; j++) {
if ((i - j) & 1) (sum -= C[i][j] * f[j]) %= MOD;
else (sum += C[i][j] * f[j]) %= MOD;
}
(res += sum * rev[i]) %= MOD;
}
for (int i = 1; i <= q; i++) vis[arr[i]] = 0;
dfs4(r, 0);
print((res + MOD) % MOD);
}
ioflush();
return 0;
}