【CodeForces】Avito Code Challenge 2018 (Div. 1 + Div. 2) 题解

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_39972971/article/details/83385345

【比赛链接】

【题解链接】

**【A】**Antipalindrome

【思路要点】

  • 当所有字符均相同时,答案为 $0$。
  • 否则,若原串为回文串,删去其最后一个字符一定会使其变成非回文串,因此答案为 $N-1$;否则答案为 $N$。
  • 时间复杂度 $O(N)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 55;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
int n;
char s[MAXN];   // input string, stored 1-indexed in s[1..n]
// Return true iff s[1..n] reads the same forwards and backwards.
bool check() {
	int lo = 1, hi = n;
	while (lo < hi) {
		if (s[lo] != s[hi]) return false;
		++lo;
		--hi;
	}
	return true;
}
int main() {
	scanf("%s", s + 1);
	n = strlen(s + 1);
	// Does every character equal the first one?
	bool allSame = true;
	for (int i = 2; i <= n && allSame; i++)
		if (s[i] != s[1]) allSame = false;
	int best;
	if (allSame) best = 0;          // every substring is a palindrome
	else if (check()) best = n - 1; // dropping the last character breaks the palindrome
	else best = n;                  // the whole string already qualifies
	printf("%d\n", best);
	return 0;
}

**【B】**Businessmen Problems

【思路要点】

  • std::mapstd::map 实现取每一种展品的最大值。
  • 时间复杂度 $O(N\log N)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 5;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
// mp[x]: income of element x for the first company.
map <int, int> mp;
// Take every offered element once. Start from the sum of all incomes; for an
// index offered by both companies, only the larger income may be kept, so
// subtract the smaller one.
// Assumes (problem constraint) indices are distinct within each company's
// list and incomes are positive.
int main() {
	ll ans = 0;
	int n; read(n);
	for (int i = 1; i <= n; i++) {
		int x, y; read(x), read(y);
		ans += y; mp[x] = y;
	}
	int m; read(m);
	for (int i = 1; i <= m; i++) {
		int x, y; read(x), read(y);
		ans += y;
		// Look up with find() rather than operator[] so indices offered only by
		// the second company are not inserted into the map.
		auto it = mp.find(x);
		if (it != mp.end()) ans -= min(it->second, y);
	}
	writeln(ans);
	return 0;
}

**【C】**Useful Decomposition

【思路要点】

  • 唯一可能的合法情况是所有路径均交于一点。
  • 取度数最大的点作为该点,$dfs$ 构造路径即可。
  • 时间复杂度 $O(N)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 5;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
int n, root;
vector <int> a[MAXN];   // adjacency lists, vertices numbered from 1
// Walk away from fa starting at pos for as long as the path does not branch,
// and return the last vertex reached. If some vertex on the way has two or
// more neighbours other than the one we came from, no valid decomposition
// exists: print "No" and terminate the program.
// Written as a loop (instead of tail recursion) so a chain of up to n
// vertices cannot exhaust the call stack.
int dfs(int pos, int fa) {
	while (true) {
		int dest = 0;   // the unique next vertex, 0 if none yet
		for (auto x : a[pos])
			if (x != fa) {
				if (dest) {
					printf("No\n");
					exit(0);
				}
				dest = x;
			}
		if (dest == 0) return pos;
		fa = pos;
		pos = dest;
	}
}
int main() {
	read(n);
	for (int e = 1; e < n; e++) {
		int u, v; read(u), read(v);
		a[u].push_back(v);
		a[v].push_back(u);
	}
	// The common vertex of all paths is taken to be a vertex of maximum
	// degree (the smallest index among ties).
	root = 1;
	for (int v = 2; v <= n; v++)
		if (a[v].size() > a[root].size()) root = v;
	// Every neighbour of root starts one path; dfs() returns its far end or
	// exits with "No" if that branch is not a simple chain.
	vector <int> ends;
	for (auto nb : a[root])
		ends.push_back(dfs(nb, root));
	printf("Yes\n");
	writeln(ends.size());
	for (auto leaf : ends)
		printf("%d %d\n", root, leaf);
	return 0;
}

**【D】**Bookshelves

【思路要点】

  • 逐位确定答案。
  • 对于答案 $ans$,我们需要判断是否能将序列分成区间和 $sum_i$ 满足 $sum_i\ \&\ ans=ans$ 的 $k$ 段,可以用动态规划判断。
  • $dp_{i,j}$ 表示能否将序列的前 $i$ 个元素划分成 $j$ 个合法区间,枚举下一个区间的右端点进行转移。
  • 时间复杂度 $O(N^2K\log V)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 55;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
int n, k;
// ans: answer built bit by bit in main(); a[1..n]: input; s[i]: prefix sums with s[0] = 0.
ll ans, s[MAXN], a[MAXN];
// Return true iff a[1..n] can be cut into exactly k non-empty consecutive
// segments whose sums each contain every bit of mask (sum & mask == mask).
bool check(ll mask) {
	// reach[i][j]: the first i elements can be split into j good segments.
	static bool reach[MAXN][MAXN];
	memset(reach, false, sizeof(reach));
	reach[0][0] = true;
	for (int i = 0; i < n; i++)
		for (int j = 0; j < k; j++) {
			if (!reach[i][j]) continue;
			// Try every right end p of the next segment (i + 1 .. p).
			for (int p = i + 1; p <= n; p++) {
				ll seg = s[p] - s[i];
				if ((seg & mask) == mask) reach[p][j + 1] = true;
			}
		}
	return reach[n][k];
}
int main() {
	read(n), read(k);
	for (int i = 1; i <= n; i++)
		read(a[i]), s[i] = s[i - 1] + a[i];
	for (int i = 60; i >= 0; i--) {
		ans += 1ll << i;
		if (!check(ans)) ans -= 1ll << i;
	}
	writeln(ans);
	return 0;
}

**【E】**Addition on Segments

【思路要点】

  • 若我们确定了区间最大值的位置,显然我们可以选择不进行不包含这个位置的操作来确保其成为区间最大值。
  • 因此,我们希望计算出包含每个位置的询问权值的背包,再合并得到答案。
  • 注意到如果用 `bitset` 来存储背包,我们可以在 $O(\frac{N}{w})$ 的时间内加入一个权值,但是不方便删除。
  • 因此,我们可以采用线段树分治的思想,将一个询问拆分为 $O(\log N)$ 个询问,放在线段树上。这样,从根节点 $dfs$ 到某个叶子节点的路径上的所有询问即为包含该位置的询问,据此进行转移即可。
  • 时间复杂度 $O(\frac{NQ\log N}{w})$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 1e4 + 5;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
// Segment tree over positions 1..n used for "segment-tree divide and conquer".
// Each operation (l, r, x) is stored on the nodes that exactly cover [l, r].
// A DFS from the root then applies, on the way to every leaf, exactly the
// operations whose range contains that leaf.
struct SegmentTree {
	struct Node {
		int lc, rc;            // child indices, 0 at a leaf
		vector <int> tag;      // values x of operations attached to this node
	} a[MAXN * 2];
	int n, size, root;
	bitset <MAXN> ans;         // OR of the bitsets reached at all leaves
	// Allocate nodes for [l, r]; root receives the new node's index.
	void build(int &root, int l, int r) {
		root = ++size;
		if (l < r) {
			int mid = (l + r) / 2;
			build(a[root].lc, l, mid);
			build(a[root].rc, mid + 1, r);
		}
	}
	void init(int x) {
		n = x;
		root = size = 0;
		build(root, 1, n);
	}
	// Attach x to the nodes that exactly cover [ql, qr] inside node root = [l, r].
	void modify(int root, int l, int r, int ql, int qr, int x) {
		if (ql == l && qr == r) {
			a[root].tag.push_back(x);
			return;
		}
		int mid = (l + r) / 2;
		if (ql <= mid) modify(a[root].lc, l, mid, ql, min(qr, mid), x);
		if (qr > mid) modify(a[root].rc, mid + 1, r, max(mid + 1, ql), qr, x);
	}
	// Record operation (l, r, x); an empty range is ignored.
	void modify(int l, int r, int x) {
		if (l <= r) modify(root, 1, n, l, r, x);
	}
	// now: bit s set <=> sum s is reachable with the operations seen on the
	// path so far. Passed by value so both children start from the same state.
	void work(int pos, bitset <MAXN> now) {
		for (int x : a[pos].tag)
			now |= now << x;
		if (!a[pos].lc) {
			ans |= now;
			return;
		}
		work(a[pos].lc, now);
		work(a[pos].rc, now);
	}
	// Print how many values in 1..n are reachable, then the values themselves.
	void getans() {
		work(root, 1);
		vector <int> finalans;
		for (int v = 1; v <= n; v++)
			if (ans[v]) finalans.push_back(v);
		writeln(finalans.size());
		for (int v : finalans)
			printf("%d ", v);
	}
} ST;
int n, m;
int main() {
	read(n), read(m);
	ST.init(n);
	// Record every operation "add x on [l, r]" in the segment tree.
	for (int q = 0; q < m; q++) {
		int l, r, x;
		read(l), read(r), read(x);
		ST.modify(l, r, x);
	}
	ST.getans();
	return 0;
}

**【F】**Round Marriage

【思路要点】

  • 首先二分答案,现在我们只需要判断某个答案 $mid$ 是否合法。
  • 不妨令 $a_1\le a_2\le a_3\le\dots\le a_N$,$b_1\le b_2\le b_3\le\dots\le b_N$。一旦确定 $a_1$ 对应的 $b_i$,最优的对应方案就一定是 $a_1-b_i,a_2-b_{i+1},\dots,a_N-b_{i-1}$(下标循环),否则,将答案调整至这样不会变劣。
  • 每一个 $a_i$ 都有一个能够对应的 $b_i$ 区间,每一个区间都会对 $a_1$ 可能的对应位置作出限制。考虑完所有 $a_i$ 后,判断 $a_1$ 可能的对应位置是否为空即可。
  • 时间复杂度 $O(N\log V)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 5;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
int n, len;
int a[MAXN], b[MAXN];   // both sorted ascending in main(), 1-indexed
// Value of b at an index pos in [1 - n, 2n], viewing b as repeated around the
// circle of length len: indices below 1 map to b[pos + n] - len, indices
// above n map to b[pos - n] + len.
int calc(int pos) {
	if (pos < 1) return b[pos + n] - len;
	if (pos > n) return b[pos - n] + len;
	return b[pos];
}
// Is there a common offset d such that |a[i] - calc(i + d)| <= mid for every i?
// [ql, qr] is the range of offsets still possible; both ends only move inward.
bool check(int mid) {
	int ql = -n, qr = n;
	for (int i = 1; i <= n; i++) {
		auto bad = [&](int d) { return abs(a[i] - calc(i + d)) > mid; };
		while (ql <= qr && bad(ql)) ql++;
		while (ql <= qr && bad(qr)) qr--;
		if (ql > qr) return false;
	}
	return true;
}
int main() {
	read(n), read(len);
	for (int i = 1; i <= n; i++) read(a[i]);
	for (int i = 1; i <= n; i++) read(b[i]);
	sort(a + 1, a + n + 1);
	sort(b + 1, b + n + 1);
	// Smallest mid in [0, len] accepted by check().
	int lo = 0, hi = len;
	while (lo < hi) {
		int mid = lo + (hi - lo) / 2;
		if (check(mid)) hi = mid;
		else lo = mid + 1;
	}
	writeln(lo);
	return 0;
}

**【G】**Magic multisets

【思路要点】

  • 用线段树维护区间内各 $multiset$ 的大小之和。
  • 对于一个操作 $l\ r\ x$,已经有 $x$ 的位置大小会 $\times 2$,没有 $x$ 的位置大小会 $+1$。
  • 因此,我们对于每一种数字,用 `std::set` 维护已经存在该数的极大区间,操作时将操作区间分成若干段,分别在线段树上操作即可,操作总数为 $O(Q)$。
  • 时间复杂度 $O(Q\log N)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 5;
const int P = 998244353;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
// x = (x + y) mod P, for x and y already in [0, P).
void adds(int &x, int y) {
	int s = x + y;
	x = s >= P ? s - P : s;
}
// x = (x * y) mod P, using a 64-bit intermediate product.
void times(int &x, int y) {
	long long prod = static_cast<long long>(x) * y;
	x = static_cast<int>(prod % P);
}
struct SegmentTree {
	struct Node {
		int lc, rc, len;
		int sum, k, b;
	} a[MAXN * 2];
	int n, size, root;
	void update(int root) {
		a[root].sum = a[a[root].lc].sum + a[a[root].rc].sum;
		if (a[root].sum >= P) a[root].sum -= P;
	}
	void build(int &root, int l, int r) {
		root = ++size;
		a[root].len = r - l + 1;
		a[root].k = 1;
		a[root].b = 0;
		a[root].sum = 0;
		if (l == r) return;
		int mid = (l + r) / 2;
		build(a[root].lc, l, mid);
		build(a[root].rc, mid + 1, r);
	}
	void init(int x) {
		n = x;
		root = size = 0;
		build(root, 1, n);
	}
	void pushdown(int root) {
		if (a[root].k == 1 && a[root].b == 0) return;
		times(a[a[root].lc].k, a[root].k);
		times(a[a[root].lc].b, a[root].k);
		times(a[a[root].lc].sum, a[root].k);
		adds(a[a[root].lc].b, a[root].b);
		adds(a[a[root].lc].sum, 1ll * a[root].b * a[a[root].lc].len % P);
		times(a[a[root].rc].k, a[root].k);
		times(a[a[root].rc].b, a[root].k);
		times(a[a[root].rc].sum, a[root].k);
		adds(a[a[root].rc].b, a[root].b);
		adds(a[a[root].rc].sum, 1ll * a[root].b * a[a[root].rc].len % P);
		a[root].k = 1, a[root].b = 0;
	}
	void modify(int root, int l, int r, int ql, int qr, int k, int b) {
		if (l == ql && r == qr) {
			times(a[root].k, k);
			times(a[root].b, k);
			times(a[root].sum, k);
			adds(a[root].b, b);
			adds(a[root].sum, 1ll * b * a[root].len % P);
			return;
		}
		pushdown(root);
		int mid = (l + r) / 2;
		if (mid >= ql) modify(a[root].lc, l, mid, ql, min(qr, mid), k, b);
		if (mid + 1 <= qr) modify(a[root].rc, mid + 1, r, max(mid + 1, ql), qr, k, b);
		update(root);
	}
	void modify(int l, int r, int k, int b) {
		if (l > r) return;
		else modify(root, 1, n, l, r, k, b);
	}
	int query(int root, int l, int r, int ql, int qr) {
		if (l == ql && r == qr) return a[root].sum;
		int ans = 0;
		pushdown(root);
		int mid = (l + r) / 2;
		if (mid >= ql) ans += query(a[root].lc, l, mid, ql, min(mid, qr));
		if (mid + 1 <= qr) ans += query(a[root].rc, mid + 1, r, max(mid + 1, ql), qr);
		return ans % P;
	}
	int query(int l, int r) {
		if (l > r) return 0;
		else return query(root, 1, n, l, r);
	}
} ST;
int n, q;
// st[v]: disjoint intervals (first, second) of positions whose multiset
// already contains v, ordered by left endpoint.
set <pair <int, int> > st[MAXN];
// Query type 2 "l r": print the sum over [l, r] kept in ST (mod P).
// Query type 1 "l r val": positions in [l, r] that already hold val are
// doubled (modify with k = 2, b = 0), the others get +1 (k = 1, b = 1).
// [l, r] is walked along the stored intervals of val, and the union of
// [l, r] with every stored interval it overlaps replaces them in st[val].
int main() {
	read(n), read(q);
	ST.init(n);
	for (int i = 1; i <= q; i++) {
		int opt, l, r, val;
		read(opt), read(l), read(r);
		if (opt == 2) writeln(ST.query(l, r));
		else {
			read(val);
			// [fl, fr]: merged interval inserted at the end.
			// last: first position of [l, r] not processed yet.
			int fl = l, fr = r, last = l;
			// tmp: first stored interval not ordered before (l, l), i.e. with
			// left end >= l; tnp: its predecessor.
			set <pair <int, int> > :: iterator tmp = st[val].lower_bound(make_pair(l, l)), tnp = tmp;
			if (tmp != st[val].begin()) {
				tnp--;
				// The predecessor starts before l; it matters only if it reaches l.
				if ((*tnp).second >= l) {
					// [l, r] lies inside it: every position already holds val.
					if ((*tnp).second >= r) {
						ST.modify(l, r, 2, 0);
						continue;
					}
					fl = (*tnp).first;
					ST.modify(last, (*tnp).second, 2, 0);
					last = (*tnp).second + 1;
					st[val].erase(tnp);
				}
			}
			// Stored intervals entirely inside [l, r]: +1 on the gap before each,
			// double the interval itself, then drop it (it is absorbed).
			while (tmp != st[val].end() && (*tmp).second <= r) {
				ST.modify(last, (*tmp).first - 1, 1, 1);
				ST.modify((*tmp).first, (*tmp).second, 2, 0);
				last = (*tmp).second + 1;
				tnp = tmp; tmp++;
				st[val].erase(tnp);
			}
			// A stored interval that starts inside [l, r] and ends after r.
			if (tmp != st[val].end() && (*tmp).first <= r) {
				fr = (*tmp).second;
				ST.modify(last, (*tmp).first - 1, 1, 1);
				ST.modify((*tmp).first, r, 2, 0);
				st[val].erase(tmp);
			} else ST.modify(last, r, 1, 1);  // remaining gap up to r
			st[val].insert(make_pair(fl, fr));
		}
	}
	return 0;
}

**【H】**K Paths

【思路要点】

  • 首先考虑一种暴力:枚举所有路径的交,它显然也是一条路径,令其为 $(x,y)$。
  • $(x,y)$ 对答案的贡献显然应当是由独立的两部分相乘得到,不妨令其为 $Ans_{x,y}\cdot Ans_{y,x}$。$Ans_{x,y}$ 的意义是:在以 $x$ 为根的树中删去 $y$ 所在的子树后,选择 $k$ 条只在 $x$ 处相交的路径的方案数。
  • 容易看出本质不同的 $Ans_{i,j}$ 只有 $2(N-1)$ 个,只要我们分别计算出了这些 $Ans_{i,j}$,上述过程显然可以通过树形 $DP$ 优化。
  • 假设在以 $x$ 为根的树中删去 $y$ 所在的子树后,剩余的子树大小分别为 $s_1,s_2,\dots,s_n$,计算多项式 $P(x)=(s_1x+1)(s_2x+1)\cdots(s_nx+1)$,则 $Ans_{x,y}$ 即为 $\sum_{i=0}^{k}a_i\binom{k}{i}\,i!$,其中 $a_i$ 表示 $P(x)$ 中 $i$ 次项的系数。
  • 对于点 $x$,我们连上 $y$ 所在子树的大小一起计算 $P'(x)$。令 $y$ 所在子树的大小为 $s$,我们计算 $Ans_{x,y}$ 时需要用的 $P(x)$ 即为 $\frac{P'(x)}{sx+1}$,可以通过简单的多项式除法(除以一次式 $sx+1$)得到。
  • xx 的度数为 DD ,计算 P(x)P&#x27;(x) 使用分治 NTTNTT ,时间复杂度为 O(DLog2D)O(DLog^2D) 。注意到不同的子树大小至多有 O(D)O(\sqrt{D}) 个,因此后续计算的时间复杂度为 O(DD)O(D\sqrt{D})
  • 时间复杂度 $O(N\log^2N+N\sqrt{N}+K)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 5;
const int P = 998244353;
// x = max(x, y), in place.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
// x = min(x, y), in place.
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Read one integer from stdin into x: skip non-digit characters (every '-'
// skipped flips the sign), then accumulate the decimal digits that follow.
// If the input ends before any digit appears, x is left as 0.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// int rather than char: EOF stays distinguishable, and isdigit() is only
	// defined for EOF and values representable as unsigned char.
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return;  // stop at end of input instead of spinning
		if (c == '-') f = -f;
	}
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Print x in decimal to stdout (no newline).
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Print the digits of a negative value without negating it first, so
		// the most negative value of T (e.g. LLONG_MIN) does not overflow.
		if (x / 10 != 0) write(-(x / 10));
		putchar('0' - x % 10);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Print x followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
namespace NTT {
	const int MAXN = 262144;   // largest supported transform length (2^18)
	const int P = 998244353;   // prime modulus, P - 1 = 119 * 2^23
	const int G = 3;           // generator used for the roots of unity (primitive root mod P)
	// x^y mod P by iterative binary exponentiation (y >= 0).
	int power(int x, int y) {
		long long base = x % P, res = 1;
		for (; y > 0; y >>= 1) {
			if (y & 1) res = res * base % P;
			base = base * base % P;
		}
		return static_cast<int>(res);
	}
	// N: current transform length (a power of two), Log = log2(N).
	// home[i]: i with its lowest Log bits reversed.
	int N, Log, home[MAXN];
	// Fill home[0..N-1] for the current N / Log.
	void NTTinit() {
		for (int i = 0; i < N; i++) {
			int ans = 0, tmp = i;
			for (int j = 1; j <= Log; j++) {
				ans <<= 1;
				ans += tmp & 1;
				tmp >>= 1;
			}
			home[i] = ans;
		}
	}
	// In-place transform of a[0..N-1]: mode 1 = forward, mode -1 = inverse
	// (the inverse also divides by N). Requires NTTinit() for the current N.
	void NTT(int *a, int mode) {
		// Bit-reversal permutation.
		for (int i = 0; i < N; i++)
			if (home[i] < i) swap(a[i], a[home[i]]);
		for (int len = 2; len <= N; len <<= 1) {
			// delta: len-th root of unity (its inverse for the inverse transform).
			int delta;
			if (mode == 1) delta = power(G, (P - 1) / len);
			else delta = power(G, P - 1 - (P - 1) / len);
			for (int i = 0; i < N; i += len) {
				int now = 1;
				for (int j = i, k = i + len / 2; k < i + len; j++, k++) {
					int tmp = a[j];
					int tnp = 1ll * a[k] * now % P;
					a[j] = (tmp + tnp) % P;
					a[k] = (tmp - tnp + P) % P;
					now = 1ll * now * delta % P;
				}
			}
		}
		if (mode == -1) {
			int inv = power(N, P - 2);
			for (int i = 0; i < N; i++)
				a[i] = 1ll * a[i] * inv % P;
		}
	}
	// Product of two polynomials given as coefficient vectors (constant term
	// first), mod P. The result has a.size() + b.size() - 1 coefficients; an
	// empty input (the zero polynomial) yields an empty result. The result
	// length must not exceed MAXN.
	vector <int> times(const vector <int> &a, const vector <int> &b) {
		// Without this guard two empty inputs give limit = -1 and resize() throws.
		if (a.empty() || b.empty()) return vector <int> ();
		N = 1, Log = 0;
		int limit = a.size() + b.size() - 1;
		int sa = a.size(), sb = b.size();
		while (N < limit) {
			N <<= 1;
			Log++;
		}
		static int tmp[MAXN], tnp[MAXN];
		for (int i = 0; i < N; i++) {
			tmp[i] = i < sa ? a[i] : 0;
			tnp[i] = i < sb ? b[i] : 0;
		}
		NTTinit();
		NTT(tmp, 1);
		NTT(tnp, 1);
		for (int i = 0; i < N; i++)
			tmp[i] = 1ll * tmp[i] * tnp[i] % P;
		NTT(tmp, -1);
		return vector <int> (tmp, tmp + limit);
	}
}
// ans[i][y]: value computed in main() for vertex i with the neighbour side of
// size y removed.
// a[]: adjacency lists.
// b[i][j]: number of vertices on neighbour a[i][j]'s side of i (filled in main()).
map <int, int> ans[MAXN];
vector <int> a[MAXN], b[MAXN];
// size[v], depth[v]: subtree size and depth set by dfs().
// sum[v]: accumulated by getans().
// finalans: the answer (mod P).
int n, m, size[MAXN], depth[MAXN], sum[MAXN], finalans;
// Root the tree at pos with parent fa: set depth[pos] = depth[fa] + 1 and
// fill in subtree sizes.
// `::size` is qualified because, with `using namespace std` under C++17, a
// bare `size` is ambiguous with std::size.
void dfs(int pos, int fa) {
	::size[pos] = 1;
	depth[pos] = depth[fa] + 1;
	for (unsigned i = 0; i < a[pos].size(); i++)
		if (a[pos][i] != fa) {
			dfs(a[pos][i], pos);
			::size[pos] += ::size[a[pos][i]];
		}
}
void update(int &x, int y) {
	x += y;
	if (x >= P) x -= P;
}
vector <int> work(int from, int l, int r) {
	if (l == r) {
		vector <int> ans;
		ans.push_back(1);
		ans.push_back(b[from][l]);
		return ans;
	}
	int mid = (l + r) / 2;
	return NTT :: times(work(from, l, mid), work(from, mid + 1, r));
}
// fac[i] = i!, inv[i] = (i!)^{-1}, vin[i] = i^{-1}, all mod P (filled by init()).
int fac[MAXN], inv[MAXN], vin[MAXN];
// x^y mod P by recursive squaring.
int power(int x, int y) {
	if (!y) return 1;
	long long half = power(x, y / 2);
	long long sq = half * half % P;
	return y % 2 ? sq * x % P : sq;
}
// Binomial coefficient C(x, y) mod P from the tables; 0 when y > x.
int getc(int x, int y) {
	return y > x ? 0 : 1ll * fac[x] * inv[y] % P * inv[x - y] % P;
}
// Fill fac[0..n], inv[0..n] and vin[1..n].
void init(int n) {
	fac[0] = 1;
	for (int i = 1; i <= n; i++)
		fac[i] = 1ll * fac[i - 1] * i % P;
	inv[n] = power(fac[n], P - 2);
	// (i-1)!^{-1} = i!^{-1} * i
	for (int i = n; i > 0; i--)
		inv[i - 1] = 1ll * inv[i] * i % P;
	for (int i = 1; i <= n; i++)
		vin[i] = power(i, P - 2);
}
// Post-order pass (call as getans(1, 0)) that adds this subtree's
// contributions to finalans, combining the per-edge values ans[v][y] from main().
// On return, sum[pos] = sum over every v in pos's subtree of
// ans[v][b[v][index of v's parent]], i.e. v's value with its parent's side
// removed. For the root (fa == 0) no such term exists for pos itself.
void getans(int pos, int fa) {
	sum[pos] = 0;
	for (unsigned i = 0; i < a[pos].size(); i++)
		if (a[pos][i] != fa) {
			getans(a[pos][i], pos);
			update(sum[pos], sum[a[pos][i]]);
		}
	// tmpans = (sum_c S_c)^2 - sum_c S_c^2 = 2 * sum_{c < c'} S_c * S_c'
	// over children c: pairs of endpoints lying in two different child
	// subtrees of pos.
	int tmpans = 0;
	update(tmpans, 1ll * sum[pos] * sum[pos] % P);
	for (unsigned i = 0; i < a[pos].size(); i++)
		if (a[pos][i] != fa) {
			update(tmpans, P - 1ll * sum[a[pos][i]] * sum[a[pos][i]] % P);
			// Pair pos itself (its value with child c's side removed) with
			// every endpoint inside c's subtree.
			update(finalans, 1ll * sum[a[pos][i]] * ans[pos][b[pos][i]] % P);
		} else update(sum[pos], ans[pos][b[pos][i]]);  // pos's own term: parent side removed
	// Halve the ordered-pair total (vin[2] is the inverse of 2 mod P).
	update(finalans, 1ll * tmpans * vin[2] % P);
}
// Read the tree and m, compute ans[i][y] for every vertex i and every distinct
// component size y hanging off i, then combine them with getans() and print
// the result. A single-vertex tree prints 0.
int main() {
	init(MAXN - 1);
	read(n), read(m);
	if (n == 1) {
		printf("%d\n", 0);
		return 0;
	}
	for (int i = 1; i <= n - 1; i++) {
		int x, y; read(x), read(y);
		a[x].push_back(y);
		a[y].push_back(x);
	}
	dfs(1, 0);
	for (int i = 1; i <= n; i++) {
		// b[i][j]: number of vertices on neighbour a[i][j]'s side of i (its
		// subtree for a child, the rest of the tree for the parent).
		// `::size` is qualified because, with `using namespace std` under
		// C++17, a bare `size` is ambiguous with std::size.
		b[i].resize(a[i].size());
		for (unsigned j = 0; j < a[i].size(); j++) {
			int dest = a[i][j];
			if (depth[dest] > depth[i]) b[i][j] = ::size[dest];
			else b[i][j] = n - ::size[i];
		}
		// now = prod_j (1 + b[i][j] * x), constant term first.
		vector <int> now = work(i, 0, b[i].size() - 1);
		for (unsigned j = 0; j < a[i].size(); j++) {
			int y = b[i][j];
			// Neighbours whose sides have equal size share one value.
			if (ans[i].count(y)) continue;
			if (m == 1) {
				ans[i][y] = 1;
				continue;
			}
			// res = now / (1 + y * x) by synthetic division from the top
			// coefficient down. Each quotient coefficient res[k - 1] = tmp[k] / y
			// is then subtracted from tmp[k - 1]. tmp[k] itself is never read
			// again, so it is not updated.
			vector <int> tmp = now;
			vector <int> res (now.size() - 1);
			for (unsigned k = now.size() - 1; k >= 1; k--) {
				int tes = 1ll * tmp[k] * vin[y] % P;
				res[k - 1] = tes;
				update(tmp[k - 1], P - tes);
			}
			// ans[i][y] = sum_k res[k] * C(m, k) * k!   (mod P)
			int val = 0;
			for (unsigned k = 0; k < res.size(); k++)
				update(val, 1ll * res[k] * getc(m, k) % P * fac[k] % P);
			ans[i][y] = val;
		}
	}
	getans(1, 0);
	writeln(finalans);
	return 0;
}
阅读更多
换一批

没有更多推荐了,返回首页