首先记录
ai,j
a
i
,
j
表示单位
i
i
生命值为的概率,那么每次修改可以
O(m)
O
(
m
)
。
对于询问,先求出
exi=∑mij=1ai,j
e
x
i
=
∑
j
=
1
m
i
a
i
,
j
表示单位
i
i
存活的概率。那么我们只要对于每一个单位,求出除了该单位的剩下
k−1
k
−
1
个单位中,有
1..k−1
1..
k
−
1
个单位的存活的概率就解决了。
设
fi,j
f
i
,
j
表示考虑前
i
i
个单位,其中个单位存活的概率,转移一轮是
O(n2)
O
(
n
2
)
的,因为对于每个单位都要转移一轮,所以是
O(n3)
O
(
n
3
)
。
因为每轮转移都只有自己没有参与转移,考虑如何从01背包中删除一个元素。
添加一个元素
f′k=fk−1⋅exi+fk⋅(1−exi)
f
k
′
=
f
k
−
1
⋅
e
x
i
+
f
k
⋅
(
1
−
e
x
i
)
。
那么删除一个元素就是
fk=f′k−fk−1∗exi1−exi
f
k
=
f
k
′
−
f
k
−
1
∗
e
x
i
1
−
e
x
i
。
只需要先把所有的转移一遍,再分别删除每一个即可。注意处理
exi=1
e
x
i
=
1
的情况,总复杂度
O(Qm+Cn2)
O
(
Q
m
+
C
n
2
)
。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 210
#define M 110
#define ll long long
using namespace std;
const int mod=998244353;
int n,c[N],Q,t[N];
ll a[N][M],f[N][N],g[N],ex[N],ans[N],inv[N];
ll ksm(ll a,ll b){ll r=1;for(;b;b>>=1){if(b&1)r=r*a%mod;a=a*a%mod;}return r;}
int read()
{
int x=0,f=1;char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=-1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x*f;
}
int main()
{
n=read();
for(int i=0;i<=n+1;i++)
inv[i]=ksm(i,mod-2);
for(int i=1;i<=n;i++)
c[i]=read(),a[i][c[i]]=1;
Q=read();
while(Q--)
{
int opt=read();
if(opt)
{
int m=read();
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
memset(ex,0,sizeof(ex));
for(int i=1;i<=m;i++)
t[i]=read();
for(int i=1;i<=m;i++)
ex[i]=(mod+1-a[t[i]][0])%mod;
f[0][0]=1;
for(int i=1;i<=m;i++)
for(int j=0;j<=i;j++)
f[i][j]=((j>0?f[i-1][j-1]*ex[i]:0)+f[i-1][j]*(mod+1-ex[i]))%mod;
for(int i=1;i<=m;i++)
ans[i]=0;
for(int i=1;i<=m;i++)
if(ex[i])
{
if(ex[i]==1)
for(int j=1;j<=m;j++)
g[j-1]=f[m][j];
else
{
ll tmp=ksm(mod+1-ex[i],mod-2);
g[0]=f[m][0]*tmp%mod;
for(int j=1;j<=m;j++)
g[j]=(f[m][j]-g[j-1]*ex[i]%mod+mod)*tmp%mod;
}
for(int j=1;j<=m;j++)
ans[i]=(ans[i]+ex[i]*g[j-1]%mod*inv[j])%mod;
}
for(int i=1;i<=m;i++)
printf("%lld ",ans[i]);
puts("");
}
else
{
int id=read();ll u=read(),v=read(),P=u*ksm(v,mod-2)%mod;
a[id][0]=(a[id][0]+P*a[id][1])%mod;
for(int i=1;i<=c[id];i++)
a[id][i]=(a[id][i]*(mod+1-P)+a[id][i+1]*P)%mod;
}
}
for(int i=1;i<=n;i++)
ans[i]=0;
for(int i=1;i<=n;i++)
for(int j=1;j<=c[i];j++)
ans[i]=(ans[i]+a[i][j]*j)%mod;
for(int i=1;i<=n;i++)
printf("%lld ",ans[i]);
puts("");
return 0;
}