瞎bb
noip全真模拟赛又挂了。。
出题人居然又贺了三道原题。。
T3.走向巅峰新年的叶子//原题链接
被出题人魔改之后的题面…
T1暴力T2爆蛋,,于是只好来做T3
思路
树的多条直径一定会相交 所以我们用最暴力的做法(去考提高的应该都会吧 先随便选一个点 找到离这个点最远的一些点 作为直径的左端点们 在随便选一个左端点找到与她最远的一些点 也就是右端点们 然后再树上乱搞即可)算出这段区间的左右两个端点
p
0
,
p
1
p_0,p_1
p0,p1
- 如果 p 0 ≠ p 1 p_0\neq p_1 p0̸=p1,所以就会有上面讲的直径的左端点们和右端点们,于是就在叶子结点和一堆端点之间做期望dp!然后发现只会 θ ( n 2 ) \theta(n^2) θ(n2)的dp。。再见!!!!
- 第一种情况
n
2
n^2
n2暴力但第二种情况总会好考虑一些。吧?如果
p
0
=
p
1
p_0=p_1
p0=p1,类似于菊花图,在
n
n
n个叶子结点中选
m
m
m个(最后只能剩下一个直径的端点)简单的
O
(
n
)
O(n)
O(n)暴力计数(求
a
a
a个结点 还有
b
(
0
<
b
≤
a
)
b(0 < b \leq a)
b(0<b≤a)个结点未被染黑 再染黑一个的期望步数是
a
b
\frac{a}{b}
ba)!!!!!然而考后大佬又给出了反例。。。。
wa的一声就哭了。。这不是要炸的节奏吗??
所以上面那个思路是假的
真·思路
反正出题人也搬了原题 所以我也去学(hè)了题解
其实直径还有一个特别好的性质,就是树的每条直径的中点都是在同一个点上的(证明略,形象理解一下就行quq)
- 如果直径的长度是偶数 那么中点一定是在树上的某个点上的 我们只需要把这个点拎到root上 于是几个直径的端点(深度为 D 2 \frac{D}{2} 2D)就被划分到了几个不同的集合 窝门只需要各个区间求期望就好了
- 如果直径的长度是奇数 那中点不是在树边上了吗??其实没有关系我们假装那有个点就好了 于是类似于第一种情况 但是发现集合只剩下两个了
我们每次都枚举一个集合,算出其他集合全部被染黑需要的期望时间,再把这些期望时间加起来,就相当于全部的点被染黑了(集合数-1)次,所以窝门再把这个期望时间和
−
-
−染黑整个端点的集合的期望时间
×
\times
×(集合数-1),这个数就是
A
N
S
\mathcal{ANS}
ANS啦
最后再加一个特别重要的预处理:
∑
i
=
1
n
1
i
\sum_{i=1}^n\frac{1}{i}
∑i=1ni1的逆元
系不系简单粗暴又好打ヽ( ̄▽ ̄)ノ
Code
还有AC代码是从原来的zz代码魔改过来的 奇丑无比 所以大佬别打我
#include <cstdio>
#include <algorithm>
#define MOD 998244353
#define N 500005
using namespace std;
typedef long long LL;
struct Node {
int to, nxt;
}e[N << 1];
int cnt, lst[N], d[N], du[N], st[N], maxi, leaves, tot, d1[N];
LL pre_inv[N];
LL dp[N];
inline void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = lst[u];
lst[u] = cnt;
}
inline LL qui_pow(LL x, int y) {
if (y == 1) return x;
LL t = qui_pow(x, y / 2);
if (y & 1) return t * t % MOD * x % MOD;
else return t * t % MOD;
}
inline void dfs(int x, int fa, int dep) {
d[x] = dep;
if (d[x] > d[maxi]) maxi = x;
for (int i = lst[x]; i; i = e[i].nxt) {
if (e[i].to == fa) continue;
dfs(e[i].to, x, dep + 1);
}
}
inline int countt(int x, int fa, int len) {
if (du[x] == 1 && d[x] == len) return 1;
int sum = 0;
for (int i = lst[x]; i; i = e[i].nxt) {
if (e[i].to == fa) continue;
sum += countt(e[i].to, x, len);
}
return sum;
}
int main() {
int n, u, v, f = 0;
scanf("%d", &n);
for (int i = 1; i < n; ++i) {
scanf("%d%d", &u, &v);
du[u]++;
du[v]++;
add(u, v);
add(v, u);
}
LL inv;
for (int i = 1; i <= n; ++i) {
inv = qui_pow(i, MOD - 2);
pre_inv[i] = (pre_inv[i - 1] + inv) % MOD;
}
for (int i = 1; i <= n; ++i) {
if (du[i] == 1) leaves++;
}
maxi = 0;
dfs(1, 1, 0);
int x = maxi;
maxi = 0;
dfs(x, x, 0);
for (int i = 1; i <= n; ++i) {
d1[i] = d[i];
}
x = maxi;
maxi = 0;
dfs(x, x, 0);
int dia = d[maxi], mid, md, all = 0;
if (dia & 1) {
for (int i = 1; i <= n; ++i) {
if (d[i] == (dia >> 1) && d1[i] == (dia >> 1) + 1) mid = i;
if (d[i] == (dia >> 1) + 1 && d1[i] == (dia >> 1)) md = i;
}
// printf("%d %d\n", mid, md);
dfs(mid, mid, 0);
int num = countt(mid, md, (dia >> 1));
if (num > 0) st[++tot] = num;
all += num;
// printf("%d\n", num);
dfs(md, md, 0);
num = countt(md, mid, (dia >> 1));
if (num > 0) st[++tot] = num;
all += num;
// printf("%d\n", num);
}
else {
for (int i = 1; i <= n; ++i) {
if (d[i] == (dia >> 1) && d1[i] == (dia >> 1)) mid = i;
}
dfs(mid, mid, 0);
for (int i = lst[mid]; i; i = e[i].nxt) {
int num = countt(e[i].to, mid, (dia >> 1));
if (num > 0) st[++tot] = num;
all += num;
}
}
LL ans = 0;
for (int i = 1; i <= tot; ++i) {
ans += pre_inv[all - st[i]];
if (ans >= MOD) ans -= MOD;
}
ans -= 1LL * (tot - 1) * pre_inv[all] % MOD;
if (ans < 0) ans += MOD;
ans = ans * leaves % MOD;
printf("%lld\n", ans);
return 0;
}