题目链接:【LOJ 565】mathematican 的二进制
题目大意:有一个 n n 位的二进制数, 个操作。第 i i 个操作是将这个二进制串的数值加上 ,有 pi p i 的几率被执行。每次操作的代价是这次操作改变的位的数量。求代价的期望值 mod998244353 mod 998244353 的结果。 n,m≤2×105 n , m ≤ 2 × 10 5
我们发现:
- 最终的答案与操作顺序无关,只与哪些操作被执行过有关。
- 因为每次进位总会让 1 1 的总个数减少 ,总代价就是所有被执行的操作的总次数的两倍减去最终剩下的数中 1 1 的个数。即:。
于是可以列出递推式: f(i,j) f ( i , j ) 表示从后往前第 i i 位总共被改变 次的概率,那么我们有两种转移:
- 进位: f(i−1,j)→f(i,⌊j2⌋) f ( i − 1 , j ) → f ( i , ⌊ j 2 ⌋ ) 。
- 操作:对于第 i i 位每个概率为 的操作, (1−p)⋅f(i,j)+p⋅f(i−1,j)→f(i,j) ( 1 − p ) ⋅ f ( i , j ) + p ⋅ f ( i − 1 , j ) → f ( i , j ) 。
发现进位可以直接转移,操作可以用分治 NTT N T T 转移。时间复杂度 O(mlog2m) O ( m log 2 m ) 。
注意:二进制数的最大值不是
2n
2
n
,而是
m2n
m
2
n
。
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int maxn = 200025;
const int maxm = 1 << 19 | 5;
int n, m, r[maxm], cur;
ll temp[2][maxm];
vector<ll> f, g, v[maxn], p[maxm];
ll mpow(ll a, ll b, ll c) {
if ((b %= mod - 1) < 0) {
b += mod - 1;
}
ll d = 1;
for (; b; b >>= 1, a = a * a % c) {
if (b & 1) {
d = d * a % c;
}
}
return d;
}
void ntt(ll *a, int n, int opt) {
for (int i = 0; i < n; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < n; k <<= 1) {
ll v = mpow(3, (mod - 1) / (k << 1) * opt, mod);
for (int i = 0; i < n; i += k << 1) {
ll w = 1;
for (int j = i; j < i + k; j++, w = w * v % mod) {
ll x = a[j], y = w * a[j + k] % mod;
a[j] = (x + y) % mod, a[j + k] = (x - y) % mod;
}
}
}
if (opt == -1) {
ll v = -(mod - 1) / n;
for (int i = 0; i < n; i++) {
a[i] = v * a[i] % mod;
}
}
}
void mult(const vector<ll> &f, const vector<ll> &g, vector<ll> &h) {
int lim, bit = 0, len = f.size() + g.size() - 1;
for (lim = 1; lim < len; lim <<= 1) bit++;
for (int i = 0; i < lim; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
for (int i = 0; i < f.size(); i++) {
temp[0][i] = f[i];
}
for (int i = f.size(); i < lim; i++) {
temp[0][i] = 0;
}
for (int i = 0; i < g.size(); i++) {
temp[1][i] = g[i];
}
for (int i = g.size(); i < lim; i++) {
temp[1][i] = 0;
}
ntt(temp[0], lim, 1);
ntt(temp[1], lim, 1);
for (int i = 0; i < lim; i++) {
temp[0][i] = temp[0][i] * temp[1][i] % mod;
}
ntt(temp[0], lim, -1);
h.clear();
for (int i = 0; i < len; i++) {
h.push_back(temp[0][i]);
}
}
void solve(const vector<ll> &v, int l, int r, vector<ll> &u) {
if (l == r) {
u.clear();
u.push_back(1 - v[l]);
u.push_back(v[l]);
return;
}
int x = cur++, y = cur++, md = (l + r) >> 1;
solve(v, l, md, p[x]);
solve(v, md + 1, r, p[y]);
mult(p[x], p[y], u);
}
void calc(const vector<ll> &v, vector<ll> &u) {
if (!v.size()) {
u.clear();
u.push_back(1);
} else {
solve(v, 0, v.size() - 1, u);
}
}
int main() {
scanf("%d %d", &n, &m);
ll sum = 0, a, p, q;
for (int i = 1; i <= m; i++) {
scanf("%lld %lld %lld", &a, &p, &q);
p = p * mpow(q, mod - 2, mod) % mod;
v[a].push_back(p), sum = (sum + p) % mod;
}
f.push_back(1);
sum = 2 * sum % mod;
for (int i = 0; i <= n + 20; i++) {
g.resize((f.size() + 1) >> 1);
fill(g.begin(), g.end(), 0);
for (int i = 0; i < f.size(); i++) {
g[i >> 1] += f[i];
}
calc(v[i], f);
mult(f, g, f);
for (int i = 1; i < f.size(); i += 2) {
sum = (sum - f[i]) % mod;
}
}
printf("%lld\n", (sum + mod) % mod);
return 0;
}