[Codeforces 1111E] Tree(虚树+二项式反演)

6 篇文章 0 订阅
3 篇文章 0 订阅

题目链接

题目大意

给定一棵树,有一些询问。每次询问给出 k k k个点和两个数 m , r m,r m,r,表示让原树以 r r r为根,把这 k k k个点分成至多 m m m组,每组内不存在一个点是另一个点的祖先。求方案数膜1000000007.
n , Q ≤ 1 0 5 , ∑ k ≤ 1 0 5 , m ≤ m i n ( k , 300 ) n,Q\le 10^5,\sum k\le 10^5,m\le min(k,300) n,Q105,k105,mmin(k,300)

题解

显然先建虚树,并且按照给定根重新遍历虚树。刚开始SB的我想了好久怎么重新确定虚树中谁是谁的祖先……后来才发现直接把 r r r加进去一起建虚树就行了qaq。
然后,看数据范围似乎是个 O ( k m ) O(km) O(km)的做法?想了一会儿树形dp,感觉不太可行。那就估计是组合数学了。
先不考虑组与组之间无区别的问题(即两组分别为{1},{2}和{2},{1}实际上是相同的情况),我们给每个组设定一个编号。遍历虚树,如果某个点向上有 x x x个祖先,那么它可以选的编号有 m − x m-x mx种,乘起来即可。
显然这样会重复,我们考虑去重。不妨令 f ( m ) f(m) f(m)表示刚刚算出的答案, g ( m ) g(m) g(m)表示恰好分成 m m m非空无区别组的方案数。那么:
f ( m ) = ∑ i = 1 m ( m i ) g ( i ) ⋅ i ! f(m)=\sum_{i=1}^m\binom mi g(i)\cdot i! f(m)=i=1m(im)g(i)i!
二项式反演即可得到:
g ( m ) = 1 i ! ∑ i = 1 m ( − 1 ) m − i ( m i ) f ( i ) g(m)=\frac{1}{i!}\sum_{i=1}^m(-1)^{m-i}\binom mi f(i) g(m)=i!1i=1m(1)mi(im)f(i)
于是我们可以在 O ( k m ) O(km) O(km)的时间内算出所有的 f f f,利用 f f f O ( m 2 ) ≤ O ( k m ) O(m^2)\le O(km) O(m2)O(km)的时间内算出所有的 g g g,直接求和就是答案。

#include <bits/stdc++.h>
namespace IOStream {
	const int MAXR = 1 << 23;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	inline char readc() {
	#ifndef ONLINE_JUDGE
		return getchar();
	#endif
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	}
	template<typename T> inline void read(T &x) {
		x = 0; register int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-');
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
		x *= flag;
	}
	template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
		read(a), read(x...);
	}
	inline int reads(char *s) {
		register int len = 0, c;
		while (isspace(c = readc()) || !c);
		s[len++] = c;
		while (!isspace(c = readc()) && c) s[len++] = c;
		s[len] = 0;
		return len;
	}
	inline void ioflush() {
		fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
		fflush(stdout);
	}
	inline void printc(char c) {
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	inline void prints(char *s) {
		for (int i = 0; s[i]; i++) printc(s[i]);
	}
	template<typename T> inline void print(T x, char c = '\n') {
		if (x < 0) printc('-'), x = -x;
		if (x) {
			static char sta[20];
			register int tp = 0;
			for (; x; x /= 10) sta[tp++] = x % 10 + '0';
			while (tp > 0) printc(sta[--tp]);
		} else printc('0');
		printc(c);
	}
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
#define cls(a) memset(a, 0, sizeof(a))

const int MAXN = 100005, MAXM = 200005, MOD = 1000000007;
struct Graph { int to, next; } gra[MAXM];
struct Edge { int to, val, next; } edge[MAXM];
int hd[MAXN], st[20][MAXM], beg[MAXN], dep[MAXN], sta[MAXN], ed[MAXN];
int lg[MAXM], head[MAXN], arr[MAXN], vis[MAXN], sz[MAXN], n, m, tot;
void addgra(int u, int v) {
    gra[++tot] = (Graph) { v, hd[u] };
    hd[u] = tot;
}
void addedge(int u, int v, int w) {
    edge[++tot] = (Edge) { v, w, head[u] };
    head[u] = tot;
    edge[++tot] = (Edge) { u, w, head[v] };
    head[v] = tot;
    //printf("%d %d %d\n", u, v, w);
}
void dfs1(int u, int fa) {
    dep[st[0][beg[u] = ++tot] = u] = dep[fa] + 1;
    sz[u] = 1;
    for (int i = hd[u]; i; i = gra[i].next) {
        int v = gra[i].to;
        if (v != fa) dfs1(v, st[0][++tot] = u), sz[u] += sz[v];
    }
    ed[u] = tot;
}
int get_min(int a, int b) { return dep[a] < dep[b] ? a : b; }
int get_lca(int a, int b) {
    a = beg[a], b = beg[b];
    if (a > b) swap(a, b);
    int l = lg[b - a + 1];
    return get_min(st[l][a], st[l][b - (1 << l) + 1]);
}
bool cmp(const int &a, const int &b) { return beg[a] < beg[b]; }
int q, r, mm;
ll C[305][305], f[305], fac[305], rev[305];
ll modpow(ll a, int b) {
	ll res = 1;
	for (; b; b >>= 1) {
		if (b & 1) res = res * a % MOD;
		a = a * a % MOD;
	}
	return res;
}
void dfs4(int u, int fa) {
	for (int &i = head[u]; i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa) dfs4(v, u);
	}
}
void dfs3(int u, int fa, int d, ll &ff) {
	for (int i = head[u]; i; i = edge[i].next) {
		int v = edge[i].to;
		if (v == fa) continue;
		dfs3(v, u, d - vis[u], ff);
	}
	if (vis[u]) (ff *= d) %= MOD;
}
int main() {
	C[0][0] = 1;
	for (int i = fac[0] = 1; i <= 300; i++) {
		fac[i] = fac[i - 1] * i % MOD;
		C[i][0] = 1;
		for (int j = 1; j <= i; j++)
			C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % MOD;
	}
	rev[300] = modpow(fac[300], MOD - 2);
	for (int i = 300; i > 0; i--) rev[i - 1] = rev[i] * i % MOD;
    read(n, m);
    for (int i = 1; i < n; i++) {
        int u, v; read(u, v);
        addgra(u, v);
        addgra(v, u);
    }
    dfs1(1, tot = 0);
    for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i < 20; i++)
    for (int j = 1; j + (1 << i) - 1 <= tot; j++)
        st[i][j] = get_min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
    while (m--) {
        int top = tot = 0, flag = 0; read(q, mm, r);
        for (int i = 1; i <= q; i++) {
        	read(arr[i]), vis[hd[i] = arr[i]] = 1;
        	if (arr[i] == r) flag = 1;
        }
        if (!flag) arr[++q] = r;
        sort(arr + 1, arr + 1 + q, cmp);
        sta[++top] = 1;
        for (int i = arr[1] == 1 ? 2 : 1; i <= q; i++) {
            int l = get_lca(sta[top], arr[i]);
            for (; top > 1 && dep[sta[top - 1]] > dep[l]; top--)
                addedge(sta[top - 1], sta[top], dep[sta[top]] - dep[sta[top - 1]]);
            if (dep[sta[top]] > dep[l]) addedge(l, sta[top], dep[sta[top]] - dep[l]), --top;
            if (dep[sta[top]] < dep[l]) sta[++top] = l;
            sta[++top] = arr[i];
        }
        for (; top > 1; top--) addedge(sta[top - 1], sta[top], dep[sta[top]] - dep[sta[top - 1]]);
        ll res = 0;
        for (int i = 1; i <= mm; i++) {
        	f[i] = 1;
        	dfs3(r, 0, i, f[i]);
        	ll sum = 0;
        	for (int j = 1; j <= i; j++) {
        		if ((i - j) & 1) (sum -= C[i][j] * f[j]) %= MOD;
        		else (sum += C[i][j] * f[j]) %= MOD;
        	}
        	(res += sum * rev[i]) %= MOD;
        }
        for (int i = 1; i <= q; i++) vis[arr[i]] = 0;
        dfs4(r, 0);
        print((res + MOD) % MOD);
    }
    ioflush();
    return 0;
}
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值