链接:B-Endless Pallet_2021牛客国庆集训派对day6 (nowcoder.com)
非常套路(连环套路...)的一道题。
首先覆盖完集合内所有元素的条件很麻烦,所以考虑用min-max容斥转化。
设一个包含树上一些节点的集合S,定义该集合的最大值为集合内最后一个被覆盖点的被覆盖时间。再设T为包含所有节点的全集,那么原问题就是求T的最大值的期望。通过min-max容斥易得问题答案为
S最小值的期望就是在树上随机选链,第一次选到一条和S有交的链的期望时间。
这个东西非常好求,S中的点将整棵树划分为了一些块,那么与S没有交的充要条件就是在只各个块内部选链,方案数为sz * (sz + 1),然后对各个块求和得到总的无交集方案,除以总方案数n * (n + 1)后得到无交集概率,然后用1减去这个概率后取倒数就是所求期望,式子为:
其中num(S)是S这个集合划分出来的连通块。
然后题目中n等于50,显然不能暴力枚举子集,考虑在树上dp
首先要维护集合S元素数量奇偶性,开一维0/1
一般涉及到连通块构建的树上dp问题都要设一维正在构建的连通块大小,然后树上背包dp。
这个式子因为分母有一个减法,所以不能简单的只设正在构建连通块大小。
但是可以发现分母中后一项一定小于前一项,也就是最多n^2左右,不超过2600,那么可以直接将这一信息开成一维dp状态,记为sta。
因此共需开4维,dp[x][cur][op][sta]表示x节点,当前正在构建的连通块大小为cur,已经选择了在集合中的元素个数的奇偶性op,分母后一项的乘积和为sta的方案数。
转移就是dfs节点,dfs孩子之后合并,注意到暴力复杂度为n^2[树上背包部分] * (n^2)^2[sta合并部分]为n^6会T,因此要用fft优化sta的合并部分,复杂度变成n^4 * log(n^2),即可ac
细节:每个节点dp初值以及之后的转移,要考虑这个点是否放入S集合两种情况,代码中枚举到cur1 = 0要特判就是x点选入S集合中的情况。
#include <bits/stdc++.h>
#define pii pair<int,int>
#define fi first
#define sc second
#define pb push_back
#define ll long long
#define trav(v, x) for(auto v:x)
#define VI vector<int>
#define VLL vector<ll>
//define double long double
#define all(x) (x).begin(),(x).end()
using namespace std;
const double eps = 1e-10;//1e-12
const int N = 55;
const ll mod = 998244353;//1e9 + 7;
void Add(ll &x, ll y)
{
x = (x + y) % mod;
}
int n, mx_cur[N];
VI adj[N];
vector<ll> f[N][N][2];
vector<ll> bin[N][2];
ll qpow(ll x, ll y = mod - 2)
{
ll res = 1;
while(y)
{
if(y & 1)
res = res * x % mod;
x = x * x % mod;
y >>= 1;
}
return res;
}
int wh[5050], len, cc;
void ntt(vector<ll> &a, bool inv)
{
for(int i = 0; i < len; i++)
if(i < wh[i])swap(a[i], a[wh[i]]);
ll tp, mo, ha;
for(int l = 2, mid; l <= len; l <<= 1)
{
mid = l >> 1;
tp = qpow(3, (mod - 1) / l);
for(int i = 0; i < len; i += l)
{
mo = 1;
for(int j = 0; j < mid; j++, mo = mo * tp % mod)
{
ha = mo * a[i + j + mid] % mod;
a[i + j + mid] = (a[i + j] - ha + mod) % mod;
a[i + j] = (a[i + j] + ha) % mod;
}
}
}
if(inv)
{
tp= qpow(len, mod - 2);
for(int i = 1; i < len / 2; i++)
swap(a[i], a[len - i]);
for(int i = 0; i < len; i++)
a[i] = a[i] * tp % mod;
}
}
vector<ll> Mul(vector<ll> x, vector<ll> y)
{
cc = 0, len = 1;
while(len < x.size() + y.size())
len <<= 1, ++cc;
for(int i = 1; i <= len; i++)
wh[i] = (wh[i >> 1] >> 1) | ((i & 1) << (cc - 1));
int sz = x.size() + y.size() - 1;
x.resize(len), y.resize(len);
ntt(x, 0), ntt(y, 0);
for(int i = 0; i < len; i++)
x[i] = x[i] * y[i] % mod;
ntt(x, 1);
x.resize(sz);
return x;
}
void dfs(int x, int ff)
{
for(int cur = 0; cur <= n; cur++)
for(int op = 0; op <= 1; op++)
f[x][cur][op].clear();
f[x][1][0].resize(1);
f[x][0][1].resize(1);
f[x][1][0][0] = 1;
f[x][0][1][0] = 1;
mx_cur[x] = 1;
trav(v, adj[x])
{
if(v == ff)
continue;
dfs(v, x);
for(int cur = 0; cur <= mx_cur[x] + mx_cur[v]; cur++)
bin[cur][0].clear(), bin[cur][1].clear();
for(int cur1 = 0; cur1 <= mx_cur[x]; cur1++)
{
for(int op1 = 0; op1 <= 1; op1++)
{
if(!f[x][cur1][op1].size())
continue;
for(int cur2 = 0; cur2 <= mx_cur[v]; cur2++)
{
for(int op2 = 0; op2 <= 1; op2++)
{
if(!f[v][cur2][op2].size())
continue;
if(cur1 == 0)
{
int nw = f[x][cur1][op1].size() + f[v][cur2][op2].size() + cur2 * (cur2 + 1) - 1;
if(bin[0][(op1 + op2) & 1].size() < nw)
bin[0][(op1 + op2) & 1].resize(nw);
}
else
{
int nw = f[x][cur1][op1].size() + f[v][cur2][op2].size() - 1;
if(bin[cur1 + cur2][(op1 + op2) & 1].size() < nw)
bin[cur1 + cur2][(op1 + op2) & 1].resize(nw);
}
vector<ll> nw;
nw = Mul(f[x][cur1][op1], f[v][cur2][op2]);
if(cur1 == 0)
{
for(int sta = 0; sta < nw.size(); sta++)
Add(bin[0][(op1 + op2) & 1][sta + cur2 * (cur2 + 1)], nw[sta]);
}
else
{
for(int sta = 0; sta < nw.size(); sta++)
Add(bin[cur1 + cur2][(op1 + op2) & 1][sta], nw[sta]);
}
/*
for(int sta1 = 0; sta1 < f[x][cur1][op1].size(); sta1++)
{
for(int sta2 = 0; sta2 < f[v][cur2][op2].size(); sta2++)
{
int nw = f[x][cur1][op1][sta1];
if(cur1 == 0)
{
Add(bin[0][(op1 + op2) & 1][sta1 + sta2 + cur2 * (cur2 + 1)], 1LL * nw * f[v][cur2][op2][sta2]);
continue;
}
Add(bin[cur1 + cur2][(op1 + op2) & 1][sta1 + sta2], 1LL * nw * f[v][cur2][op2][sta2]);
}
}
*/
}
}
}
}
mx_cur[x] += mx_cur[v];
for(int cur = 0; cur <= mx_cur[x]; cur++)
{
for(int op = 0; op <= 1; op++)
{
f[x][cur][op] = bin[cur][op];
}
}
}
// cerr << x << '\n';
// for(int cur = 0; cur <= mx_cur[x]; cur++)
// {
// for(int op = 0; op <= 1; op++)
// {
// for(int j = 0; j < f[x][cur][op].size(); j++)
// cerr << cur << ' ' << op << ' ' << j << ' ' << f[x][cur][op][j] << '\n';
// }
// }
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
int tt;
cin >> tt;
int cas = 0;
while(tt--)
{
++cas;
cin >> n;
for(int i = 1; i <= n; i++)
adj[i].clear();
for(int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
adj[x].pb(y);
adj[y].pb(x);
}
dfs(1, 0);
//cerr << "??" << '\n';
f[1][0][0].resize(n * (n + 1) + 5);
f[1][0][1].resize(n * (n + 1) + 5);
for(int cur = 1; cur <= mx_cur[1]; cur++)
for(int op = 0; op <= 1; op++)
for(int sta = 0; sta < f[1][cur][op].size(); sta++)
if(f[1][cur][op][sta])Add(f[1][0][op][sta + cur * (cur + 1)], f[1][cur][op][sta]);
ll ans = 0;
for(int op = 0; op <= 1; op++)
{
for(int sta = 0; sta < n * (n + 1) && sta < f[1][0][op].size(); sta++)
{
Add(ans, 1LL * n * (n + 1) * qpow(n * (n + 1) - sta) % mod * f[1][0][op][sta] % mod * (op ? 1 : (mod - 1)) % mod);
}
}
cout << "Case #" << cas << ':' << ' ';
cout << ans << '\n';
}
}
/*
1
8
1 2
1 3
1 4
2 5
2 8
5 6
5 7
*/