LOJ传送门
题目描述
有一个长度很大的二进制串,初始时它的每一位都为
0
0
0。现在有
m
m
m 个操作,其中第
i
i
i 个操作是将这个二进制串的数值加上
2
a
i
(
0
≤
a
i
≤
n
)
{2}^{a_i}({0}\leq{a_i}\leq{n})
2ai(0≤ai≤n),或者说,给第 $a_i $位加上
1
1
1 并进位,我们称每次操作的代价是这次操作改变的位的数量。例如,当前的二进制串是 10111
时,如果给它加上
2
0
2^0
20,串就变成了 11000
,其中从低到高第 0,1,2,30,1,2,3 位发生了改变,那么这次操作代价为
4
4
4。
我们以一定概率执行这些操作:第 i i i 个操作有 p i p_i pi 的概率执行,否则不执行。请求出所有执行的操作的代价和的期望。
你只需要求出期望改变的位数在模 998244353 998244353 998244353 意义下的值。具体来说,如果你算出来的期望 E = P Q E = {\frac{P}{Q}} E=QP,其中 P , Q P,Q P,Q 互质,那么你只要输出 ( P Q − 1 ) ( m o d 998244353 ) ({P}{Q^{-1}})\pmod{998244353} (PQ−1)(mod998244353),其中 Q − 1 ( m o d 998244353 ) Q^{-1}\pmod{998244353} Q−1(mod998244353) 表示 Q Q Q 在 ( m o d 998244353 ) \pmod{998244353} (mod998244353) 意义下的逆元。
注意:执行完操作后,该串去除前导 0 0 0 后的长度可能大于 n n n。
输入输出格式
输入格式
第一行两个用空格分隔的正整数 n , m n,m n,m,分别表示 a i a_i ai 的范围和操作数,如上文所述。
接下来 m m m 行,每行三个正整数 a i , x i , y i a_i, x_i, y_i ai,xi,yi,其中 p i = x i y i p_i = {\frac{x_i}{y_i}} pi=yixi。
输出格式
仅一行,表示答案。
输入输出样例
样例输入 #1
4 4
0 1 2
0 1 2
0 1 2
0 1 2
样例输出 #1
187170819
样例输入 #2
233 6
1 166 233
2 233 666
3 166 266
4 233 266
5 233 666
6 166 233
样例输出 #2
56615945
解题分析
发现我们每次加入一个 1 1 1, 一定是消掉了前面的一段 1 1 1, 然后再添加上一个 1 1 1, 消掉的部分都贡献了两次。
所以设期望的操作次数为 x x x, 最后期望剩下 l e f lef lef位, 那么总期望为 2 x − l e f 2x-lef 2x−lef。
x x x很好算, 就是所有的概率加起来。 关键在于 l e f lef lef怎么算。
我们先不考虑进位的情况。 设 d p [ i ] [ j ] dp[i][j] dp[i][j]表示第 i i i位最终有 j j j次 + 1 +1 +1操作的概率, 那么我们可以构造一个多项式, 第 j j j位的系数表示 d p [ i ] [ j ] dp[i][j] dp[i][j], 那么实质上就是求 ( p 1 x + ( 1 − p 1 ) ) ∗ ( p 2 x + ( 1 − p 2 ) ) ∗ . . . (p_1x+(1-p_1))*(p_2x+(1-p_2))*... (p1x+(1−p1))∗(p2x+(1−p2))∗...这个式子, 分治 N T T NTT NTT搞一下就好了。
然后考虑进位, 设第
i
i
i位最多有
t
i
t_i
ti次操作, 那么大概是这样的:
d
p
[
i
+
1
]
[
j
]
=
∑
k
=
0
t
i
d
p
[
i
]
[
⌊
k
2
⌋
]
∗
d
p
[
i
+
1
]
[
j
−
⌊
k
2
⌋
]
dp[i+1][j]=\sum_{k=0}^{t_i}dp[i][\lfloor\frac{k}{2}\rfloor]*dp[i+1][j-\lfloor\frac{k}{2}\rfloor]
dp[i+1][j]=k=0∑tidp[i][⌊2k⌋]∗dp[i+1][j−⌊2k⌋]
同样
N
T
T
NTT
NTT即可。
因为一位最多造成的贡献是 ∑ i = 0 ∞ 2 − i = 2 \sum_{i=0}^\infin2^{-i}=2 ∑i=0∞2−i=2, 所以实际上就带一个最大为 2 2 2的常数。
代码如下:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <vector>
#include <cassert>
#define R register
#define IN inline
#define W while
#define gc getchar()
#define ll long long
#define MOD 998244353
#define G 3
#define Ginv 332748118
#define MX 400500
#define ls (now << 1)
#define rs (now << 1 | 1)
template <class T>
IN void in(T &x)
{
x = 0; R char c = gc;
for (; !isdigit(c); c = gc);
for (; isdigit(c); c = gc)
x = (x << 1) + (x << 3) + c - 48;
}
IN int fpow(R int base, R int tim)
{
int ret = 1;
W (tim)
{
if (tim & 1) ret = 1ll * ret * base % MOD;
base = 1ll * base * base % MOD, tim >>= 1;
}
return ret;
}
int n, m, ans;
int tim[MX], rev[MX];
std::vector <int> pin[MX], pout[MX], p[MX];
std::vector <int> vec[MX << 2];
std::vector <int> x, y;
IN void NTT(std::vector <int> &dat, R int len, R bool typ)
{
for (R int i = 1; i < len; ++i) if (rev[i] > i) std::swap(dat[i], dat[rev[i]]);
R int seg, now, cur, bd, step, foo, bar, deal, base;
for (seg = 1; seg < len; seg <<= 1)
{
base = fpow(typ ? G : Ginv, (MOD - 1) / (seg << 1)), step = seg << 1;
for (now = 0; now < len; now += step)
{
deal = 1, bd = now + seg;
for (cur = now; cur < bd; cur++, deal = 1ll * deal * base % MOD)
{
foo = dat[cur], bar = 1ll * dat[cur + seg] * deal % MOD;
dat[cur] = (foo + bar) % MOD, dat[cur + seg] = (foo - bar + MOD) % MOD;
}
}
}
if (typ) return;
int inv = fpow(len, MOD - 2);
for (R int i = 0; i < len; ++i) dat[i] = 1ll * dat[i] * inv % MOD;
}
void Get(R int now, R int lef, R int rig, R int id)
{
vec[now].clear();
if (lef == rig)
{
vec[now].resize(2);
vec[now][0] = pout[id][lef], vec[now][1] = pin[id][lef];
return;
}
int mid = lef + rig >> 1;
Get(ls, lef, mid, id), Get(rs, mid + 1, rig, id);
int seg = rig - lef + 1, len = 1, lg = 0;
W (len <= seg) len <<= 1, lg++;
vec[ls].resize(len + 1), vec[rs].resize(len + 1), vec[now].resize(len + 1);
for (R int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
NTT(vec[ls], len, 1), NTT(vec[rs], len, 1);
for (R int i = 0; i < len; ++i) vec[ls][i] = 1ll * vec[ls][i] * vec[rs][i] % MOD;
NTT(vec[ls], len, 0);
for (R int i = 0; i <= seg; ++i) vec[now][i] = vec[ls][i];
}
int main(void)
{
int foo, PIN, POUT, bar, a, b;
in(n), in(m);
for (R int i = 1; i <= m; ++i)
{
in(foo), in(a), in(b);
bar = fpow(b, MOD - 2);
PIN = 1ll * a * bar % MOD;
POUT = 1ll * ((b - a) % MOD + MOD) % MOD * bar % MOD;
++tim[foo], pin[foo].push_back(PIN), pout[foo].push_back(POUT);
(ans += PIN) %= MOD;
}
ans = ans * 2 % MOD;
for (R int i = 0; i <= n; ++i)
if (tim[i])
{
int sum = 0;
vec[1].clear();
Get(1, 0, tim[i] - 1, i);
for (R int j = 0; j <= tim[i]; ++j)
p[i].push_back(vec[1][j]);
}
else p[i].push_back(1);
p[n + 1].push_back(1);
for (R int i = 0; i <= n; ++i)
{
for (R int j = 1; j <= tim[i]; j += 2)
ans = (ans - p[i][j] + MOD) % MOD;
foo = tim[i + 1] + tim[i] / 2;
int len = 1, lg = 0;
W (len <= foo) lg++, len <<= 1;
x.resize(len + 1), y.resize(len + 1);
for (R int j = 0; j < len; ++j)
{
x[j] = y[j] = 0;
rev[j] = (rev[j >> 1] >> 1) | ((j & 1) << lg - 1);
}
for (R int j = 0; j <= tim[i]; ++j) (x[j / 2] += p[i][j]) %= MOD;
for (R int j = 0; j <= tim[i + 1]; ++j) y[j] = p[i + 1][j];
NTT(x, len, 1), NTT(y, len, 1);
for (R int j = 0; j < len; ++j) x[j] = 1ll * x[j] * y[j] % MOD;
NTT(x, len, 0);
tim[i + 1] += tim[i] / 2;
p[i + 1].resize(tim[i + 1] + 1);
for (R int j = 0; j <= foo; ++j) p[i + 1][j] = x[j];
}
int cur = n + 1;
W (tim[cur])
{
for (R int i = 1; i <= tim[cur]; i += 2)
ans = (ans - p[cur][i] + MOD) % MOD;
tim[cur + 1] = tim[cur] / 2;
p[cur + 1].resize(tim[cur + 1] + 1);
for (R int i = 0; i <= tim[cur]; ++i)
(p[cur + 1][i / 2] += p[cur][i]) %= MOD;
++cur;
}
printf("%d\n", ans);
}