D - Three Days Ago (atcoder.jp)
思路:
抽象一下就是求每种数字的个数都是偶数的字串个数。
一共就10个数字,用二进制的每一位表示一个数字的奇偶性,开个桶记录即可。对于每一个位置,和前面每一个奇偶性完全相同的位置都能组合成一个成立子串,结果加上cnt再让cnt加上该位置的贡献。
#include <bits/stdc++.h>
using namespace std;
#define io ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
typedef long long ll;
#define i64 __int64
//#define int ll
#define pb push_back
#define eb emplace_back
#define m_p make_pair
#define mod 998244353
#define mem(a,b) memset(a,b,sizeof a)
#define pii pair<int,int>
#define fi first
#define se second
#define inf 0x3f3f3f3f
const int N = 3e5 + 50;
//__builtin_ctzll(x);后导0的个数
//__builtin_popcount计算二进制中1的个数
int cnt[1025];
void work() {
int flag = 0;
ll ans = 0;
string s;
cin >> s;
cnt[0] = 1;
for (int i = 0; i < s.size(); ++i) {
flag ^= (1 << (s[i] - '0'));
ans += cnt[flag];
cnt[flag]++;
}
cout << ans << '\n';
}
signed main() {
io;
work();
return 0;
}
思路:
用f[i]表示第k位大于等于i的可能种类数,f[i] -f[i - 1]就是第k位大于i的可能种类数。期望就是
种类数*数值 / 所有可能,所有可能就是m^cnt。
求f[i]的时候枚举第k位之前可以操作的数的个数,令这些数严格小于i,剩下的数大于等于i,即可保证第k位大于等于i。
注意取模,会爆ll。
#include <bits/stdc++.h>
using namespace std;
#define io ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
typedef long long ll;
#define i64 __int64
#define int ll
#define pb push_back
#define eb emplace_back
#define m_p make_pair
#define mod 998244353
#define mem(a,b) memset(a,b,sizeof a)
#define pii pair<int,int>
#define fi first
#define se second
#define inf 0x3f3f3f3f
const int N = 2e3 + 50;
//__builtin_ctzll(x);后导0的个数
//__builtin_popcount计算二进制中1的个数
int f[N], cnt, sum[N], fac[N], inv[N];
ll qp(ll a, ll b) {
ll ans = 1;
a %= mod;
while (b) {
if (b & 1)
ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
ll C(int a, int b) {
if (a < b or a < 0 or b < 0)
return 0;
return 1ll * fac[a] * inv[b] % mod * inv[a - b] % mod;
}
void init() {
fac[0] = fac[1] = inv[0] = inv[1] = 1;
for (int i = 2; i < N; ++i) {
fac[i] = 1ll * fac[i - 1] * i % mod; //阶乘
inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod; //逆元
}
for (int i = 2; i < N; ++i) {
inv[i] = 1ll * inv[i] * inv[i - 1] % mod;
}
}
void work() {
init();
int n, m, k;
cin >> n >> m >> k;
for (int i = 1; i <= n; ++i) {
int x;
cin >> x;
if (x == 0)
cnt++;
else
sum[x]++;
}
for (int i = 1; i <= m; ++i) {
sum[i] += sum[i - 1];
}
for (int i = 1; i <= m; ++i) {
if (sum[i - 1] >= k)
break;
for (int j = 0; j < min(cnt + 1, k - sum[i - 1]); ++j) { //枚举k之前的数
f[i] += C(cnt, j) % mod * qp(i - 1, j) % mod * qp(m - i + 1, cnt - j) % mod;
f[i] %= mod;
}
}
ll res = 0;
for (int i = 1; i <= m; ++i) {
res += (f[i] - f[i + 1]) % mod * i % mod;
res %= mod;
}
res *= qp(qp(m, cnt) % mod, mod - 2) % mod;
res %= mod;
res = (res + mod) % mod;
cout << res << '\n';
}
signed main() {
io;
work();
return 0;
}