听说这玩意叫 PGF?
方便起见,令 p i = p i ∑ j p j p_i=\frac{p_i}{\sum_jp_j} pi=∑jpjpi。
设
F
i
(
x
)
F_i(x)
Fi(x) 表示对于第
i
i
i 个开关而言,对其进行
k
k
k 次操作之后,它达到目标状态的概率的 EGF(其实文字不好表达
F
i
(
x
)
F_i(x)
Fi(x) 的意思,因为它只是一个辅助生成函数。看下去就能理解
F
i
(
x
)
F_i(x)
Fi(x) 的用途),那么有:
F
i
(
x
)
=
∑
k
≥
0
[
k
m
o
d
2
=
s
i
]
p
i
k
x
k
k
!
=
e
p
i
x
+
(
−
1
)
s
i
e
−
p
i
x
2
F_i(x)=\sum_{k\geq 0}[k\bmod 2=s_i]\frac{p_i^kx^k}{k!}=\frac{e^{p_ix}+(-1)^{s_i}e^{-p_ix}}{2}
Fi(x)=k≥0∑[kmod2=si]k!pikxk=2epix+(−1)sie−pix
然后设
G
E
(
x
)
G_E(x)
GE(x) 表示
k
k
k 次操作后全体恰好达到目标状态的概率的 EGF(这里的释义才是准确的),有:
G
E
(
x
)
=
∏
i
=
1
n
F
i
(
x
)
G_E(x)=\prod_{i=1}^nF_i(x)
GE(x)=i=1∏nFi(x)
设
R
E
(
x
)
R_E(x)
RE(x) 表示
k
k
k 次操作后恰好达到全
0
0
0 状态的概率的 EGF,类似地有:
R
E
(
x
)
=
∏
i
=
1
n
(
∑
k
≥
0
,
2
∣
k
p
i
k
x
k
k
!
)
=
∏
i
=
1
n
(
e
p
i
x
+
e
−
p
i
x
2
)
R_E(x)=\prod_{i=1}^n\left(\sum_{k\geq 0,2|k}\frac{p_i^kx^k}{k!}\right)=\prod_{i=1}^n\left(\frac{e^{p_ix}+e^{-p_ix}}{2}\right)
RE(x)=i=1∏n⎝⎛k≥0,2∣k∑k!pikxk⎠⎞=i=1∏n(2epix+e−pix)
设
G
O
(
x
)
G_O(x)
GO(x) 表示
G
E
(
x
)
G_E(x)
GE(x) 转回 OGF 后的函数,
R
O
(
x
)
R_O(x)
RO(x) 同理。
设
H
(
x
)
H(x)
H(x) 表示
k
k
k 次操作后全体恰好达到目标状态、且之前都没有达到过的概率的 OGF。那么:
H
(
x
)
R
O
(
x
)
=
G
O
(
x
)
H(x)R_O(x)=G_O(x)
H(x)RO(x)=GO(x)
我们要求的是
H
′
(
1
)
H'(1)
H′(1)。
但事实上并没有说的那么简单,一个最主要的问题就是 R , G R,G R,G 都是无限项的,所以我们先要求出 R , G R,G R,G 的封闭形式。
观察
G
E
(
x
)
G_E(x)
GE(x):
G
E
(
x
)
=
∏
i
=
1
n
e
p
i
x
+
(
−
1
)
s
i
e
−
p
i
x
2
G_E(x)=\prod_{i=1}^n\frac{e^{p_ix}+(-1)^{s_i}e^{-p_ix}}{2}
GE(x)=i=1∏n2epix+(−1)sie−pix
发现它展开后封闭形式是
∑
k
g
k
e
k
x
\sum_k g_ke^{kx}
∑kgkekx,这种形式转 OGF 是方便的:
G
O
(
x
)
=
∑
k
g
k
1
−
k
x
G_O(x)=\sum_k \frac{g_k}{1-kx}
GO(x)=∑k1−kxgk。
那么我们只需求出所有的 k k k 和对应的 a k a_k ak 即可。但看起来 k k k 的个数貌似是 O ( 2 n ) O(2^n) O(2n) 级别的。
事实上,可以发现 k k k 一定是 k ′ ∑ j p j \frac{k'}{\sum_{j} p_j} ∑jpjk′ 的形式,其中 k ′ ∈ [ − ∑ j p j , ∑ j p j ] k'\in[-\sum_j p_j,\sum_j p_j] k′∈[−∑jpj,∑jpj]。那么 k k k 只有 O ( ∑ j p j ) O(\sum_jp_j) O(∑jpj) 种。那么我们可以通过一个 O ( n ∑ j p j ) O(n\sum_jp_j) O(n∑jpj) 的 DP 求出每个 g k g_k gk。
同理我们也可以求出
R
O
(
x
)
=
∑
k
r
k
1
−
k
x
R_O(x)=\sum_k\frac{r_k}{1-kx}
RO(x)=∑k1−kxrk。那么:
H
(
x
)
=
∑
k
g
k
1
−
k
x
∑
k
r
k
1
−
k
x
=
g
1
+
∑
k
≠
1
g
k
(
1
−
x
)
1
−
k
x
r
1
+
∑
k
≠
1
r
k
(
1
−
x
)
1
−
k
x
\begin{aligned} H(x)&=\frac{\sum_{k}\frac{g_k}{1-kx}}{\sum_k\frac{r_k}{1-kx}}\\ &=\frac{g_1+\sum_{k\neq 1}\frac{g_k(1-x)}{1-kx}}{r_1+\sum_{k\neq 1}\frac{r_k(1-x)}{1-kx}} \end{aligned}
H(x)=∑k1−kxrk∑k1−kxgk=r1+∑k=11−kxrk(1−x)g1+∑k=11−kxgk(1−x)
令
A
(
x
)
A(x)
A(x) 为分母,
B
(
x
)
B(x)
B(x) 为分子。有:
H
(
x
)
′
=
A
(
x
)
′
B
(
x
)
−
A
(
x
)
B
(
x
)
′
B
(
x
)
2
H(x)'=\frac{A(x)'B(x)-A(x)B(x)'}{B(x)^2}
H(x)′=B(x)2A(x)′B(x)−A(x)B(x)′
所以我们只需求出
A
(
1
)
,
A
(
1
)
′
,
B
(
1
)
,
B
(
1
)
′
A(1),A(1)',B(1),B(1)'
A(1),A(1)′,B(1),B(1)′ 即可。这里以
A
(
x
)
′
A(x)'
A(x)′ 为例:
A
(
x
)
′
=
[
g
1
+
∑
k
≠
1
g
k
(
1
−
x
)
1
−
k
x
]
′
=
∑
k
≠
1
g
k
[
1
−
x
1
−
k
x
]
′
=
∑
k
≠
1
g
k
k
−
1
(
1
−
k
x
)
2
A
(
1
)
′
=
∑
k
≠
1
g
k
k
−
1
(
1
−
k
)
2
=
∑
k
≠
1
g
k
k
−
1
\begin{aligned} A(x)'&=\left[g_1+\sum_{k\neq 1}\frac{g_k(1-x)}{1-kx}\right]'\\ &=\sum_{k\neq 1}g_k\left[\frac{1-x}{1-kx}\right]'\\ &=\sum_{k\neq 1}g_k\frac{k-1}{(1-kx)^2}\\ A(1)'&=\sum_{k\neq 1}g_k\frac{k-1}{(1-k)^2}\\ &=\sum_{k\neq 1}\frac{g_k}{k-1} \end{aligned}
A(x)′A(1)′=⎣⎡g1+k=1∑1−kxgk(1−x)⎦⎤′=k=1∑gk[1−kx1−x]′=k=1∑gk(1−kx)2k−1=k=1∑gk(1−k)2k−1=k=1∑k−1gk
总时间复杂度
O
(
n
∑
j
p
j
)
O(n\sum_jp_j)
O(n∑jpj)。
#include<bits/stdc++.h>
#define N 110
#define SP 50010
using namespace std;
namespace modular
{
const int mod=998244353,inv2=(mod+1)>>1;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
inline void Mul(int &x,int y){x=1ll*x*y%mod;}
inline int poww(int a,int b){int ans=1;for(;b;Mul(a,a),b>>=1)if(b&1)Mul(ans,a);return ans;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
int n,p[N],sp,sp2,invsp;
int g[SP<<1],r[SP<<1];
bool s[N];
void trans(int *f,int p,bool neg)
{
static int ff[SP<<1];
int c1=inv2,c2=(neg?dec(0,inv2):inv2);
for(int i=0;i+p<=sp2;i++) Add(ff[i+p],mul(c1,f[i]));
for(int i=sp2;i-p>=0;i--) Add(ff[i-p],mul(c2,f[i]));
for(int i=0;i<=sp2;i++) f[i]=ff[i],ff[i]=0;
}
int calc(int *f)
{
return f[sp2];
}
int calcd(int *f)
{
int ans=0;
for(int i=-sp;i<=sp;i++)
if(i!=sp) Add(ans,mul(f[i+sp],poww(dec(mul((i+mod)%mod,invsp),1),mod-2)));
return ans;
}
int main()
{
n=read();
for(int i=1;i<=n;i++) s[i]=read();
for(int i=1;i<=n;i++) Add(sp,p[i]=read());
sp2=sp<<1,invsp=poww(sp,mod-2);
g[sp]=r[sp]=1;
for(int i=1;i<=n;i++) trans(g,p[i],s[i]);
for(int i=1;i<=n;i++) trans(r,p[i],0);
int A=g[sp2],Ad=calcd(g),B=r[sp2],Bd=calcd(r);
printf("%d\n",mul(dec(mul(Ad,B),mul(A,Bd)),poww(mul(B,B),mod-2)));
return 0;
}