AtCoder Beginner Contest 215 G - Colorful Candies 2
有n个糖果,每个糖果有着一个颜色a[i],每次拿k个糖果期望拿到E(x)个不同颜色的糖果,求出k从1~n分别得到的E(x)。最终答案mod998244353。
数据范围n <= 5e4
看到这个问题首先反应是答案和颜色的具体值以及出现的顺序无关,所以离散化之后用一个cnt数组记录他们分别的出现次数之后就可以不用去管a[i]了。
之后可以发现,要求出选k个糖果后不同颜色的糖果的期望值,就是求出每一种颜色的糖果至少出现一次的概率,然后求和。P(至少出现一次)=1-P(一次都不出现),而一次都不出现的概率,假如当前这个颜色的糖果有x个,那么就是 C n − x k / C n k C_{n-x}^{k}/C_n^k Cn−xk/Cnk。
如果就这样暴力求解的话,最终的复杂度最坏为 O ( n 2 ) O(n^2) O(n2)的,不过可以想到很明显的优化,一个是当k>n-x时一次都不出现的概率为0,所以可以将cnt的值进行排序然后从小到大枚举,一旦达到k>n-cnt[i]的条件就可以直接跳出循环。另一个就是相同是cnt进行去重,求出的答案直接乘以相同cnt的数量,由于 ∑ c n t = n \sum cnt=n ∑cnt=n,当满足cnt直接各不相同时,cnt个数为 O ( n ) O(\sqrt n) O(n)。这样总复杂度就是 O ( n n ) O(n\sqrt n) O(nn),时限给4秒很宽松。
#include <bits/stdc++.h>
#define f(i, l, r) for(int i = l; i <= r; i ++)
#define nf(i, r, l) for(int i = r; i >= l; i --)
typedef long long ll;
using namespace std;
const int N = 5e4 + 10, mod = 998244353;
int a[N], cnt[N], tot, ccnt[N];
long long jc[N], ny[N];
long long C(int n, int m, int num)
{
if (m == 0 || m == n)
return 1;
long long res = 1;
if (!num)
{
res = jc[n] * ny[m];
res %= mod;
res *= ny[n - m];
res %= mod;
}
else
{
res = jc[m] * jc[n - m];
res %= mod;
res *= ny[n];
res %= mod;
}
return res;
}
long long qpow(long long x, int y)
{
long long res = 1;
while (y)
{
if (y & 1)
{
res *= x;
res %= mod;
}
x *= x;
x %= mod;
y >>= 1;
}
return res;
}
void init(int n)
{
jc[0] = 1;
f(i, 1, n)
{
jc[i] = jc[i - 1] * i;
jc[i] %= mod;
}
f(i, 1, n)
ny[i] = qpow(jc[i], mod - 2);
}
int main()
{
#ifdef jinxes6
freopen("in.txt", "r", stdin);
#endif
ios::sync_with_stdio(false);
cin.tie(0);
int T = 1;
//cin >> T;
while (T --)
{
int n;
cin >> n;
init(n);
f(i, 1, n)
cin >> a[i];
sort (a + 1, a + n + 1);
f(i, 1, n)
{
if (a[i] > a[i - 1])
cnt[++ tot] = 1;
else
cnt[tot] ++;
}
sort (cnt + 1, cnt + tot + 1);
int p = 1, t = tot;
f(i, 1, tot)
{
if (cnt[p] != cnt[i])
{
ccnt[++ p] = 1;
cnt[p] = cnt[i];
}
else
ccnt[p] ++;
}
tot = p;
f(i, 1, n)
{
long long ans = t;
f(j, 1, tot)
{
if (i > n - cnt[j])
break;
long long temp = C(n - cnt[j], i, 0) * C(n, i, 1) % mod;
temp = temp * ccnt[j] % mod;
ans = (ans - temp + mod) % mod;
}
cout << ans << "\n";
}
}
return 0;
}