题目链接
如果我们知道
1
1
1到
k
−
1
k-1
k−1中每个数的出现次数分别为
c
1
,
c
2
,
.
.
.
,
c
k
−
1
c_1,c_2,...,c_{k-1}
c1,c2,...,ck−1,用排列组合相关知识就可以算出这样的数就有
(
∑
i
=
0
k
−
1
c
i
)
!
∏
j
=
0
k
−
1
c
i
!
\frac{(\sum_{i=0}^{k-1}c_i)!}{\prod_{j=0}^{k-1}c_i!}
∏j=0k−1ci!(∑i=0k−1ci)!个。
然后我们构造多项式
f
i
(
n
)
=
∑
j
=
0
n
g
i
,
j
j
!
x
j
f_i(n)=\sum_{j=0}^n\frac{g_{i,j}}{j!}x^j
fi(n)=∑j=0nj!gi,jxj,我们会发现初始答案就是把
k
−
1
k-1
k−1个多项式乘起来后每项再乘上
i
!
i!
i!后求和,利用NTT可以在
O
(
n
k
2
l
o
g
(
n
k
)
)
O(nk^2log(nk))
O(nk2log(nk))的时间复杂度内算出答案。
然后我们考虑怎么修改。
我们会发现每次修改相当于在
f
i
f_i
fi上加一个只有一项不为
0
0
0的多项式后再和其它多项式乘起来。
多项式加法也可以在点值表达式中计算,所以我们考虑这个多项式
D
F
T
DFT
DFT后会变成什么样:
y
k
=
∑
i
=
0
n
−
1
a
i
ω
n
i
k
y_k=\sum_{i=0}^{n-1}a_i\omega_n^{ik}
yk=i=0∑n−1aiωnik
由于只有一项
a
i
a_i
ai是非零的,设这一项为
a
j
a_j
aj,显然
a
j
=
±
1
j
!
a_j=\pm\frac1{j!}
aj=±j!1
所以
y
k
=
±
ω
n
j
k
j
!
y_k=\pm\frac{\omega_n^{jk}}{j!}
yk=±j!ωnjk。
怎么把这个多项式和
f
i
f_i
fi加起来之后再快速算出答案?
显然不能加了之后再重新把这些多项式乘起来,因为这样每次修改的时间复杂度是
O
(
n
k
l
o
g
(
n
k
)
)
O(nklog(nk))
O(nklog(nk))的,显然会TLE。
所以我们可以把之前得到的答案的多项式除掉
f
i
f_i
fi然后再把更新后的
f
i
f_i
fi乘进来。
具体怎么除?
我们可以保留答案的多项式点值表达式的形式,然后逐项除。
如果遇到0怎么办?
我们可以额外记录答案多项式中每位乘0的个数,遇到每项乘0的时候++,而不是在多项式中直接乘,除的时候同理。
这样每次修改的时间复杂度就是
O
(
n
k
)
O(nk)
O(nk),总时间复杂度为
O
(
n
k
2
l
o
g
(
n
k
)
+
m
n
k
)
O(nk^2log(nk)+mnk)
O(nk2log(nk)+mnk)
代码:
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=14010,P=1000010,mod=786433;
int k,n,m,Ans,limit=1,fac[P],invfac[P],inv[P],pw[P],d[15][N*20],f[N*20],f2[N*20],ans[N*20],rev[N*20];
char a[15][N];
int Add(int a,int b){
return a+b>=mod?a+b-mod:a+b;
}
int Minus(int a,int b){
return a-b<0?a-b+mod:a-b;
}
int qpow(int x,int n){
int ret=1;
while(n){
if(n&1)
ret=1ll*ret*x%mod;
x=1ll*x*x%mod;
n>>=1;
}
return ret;
}
const int g=10,gi=qpow(g,mod-2);
void NTT(int *a,int limit,int type){
for(int i=0;i<limit;i++)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int mid=1;mid<limit;mid<<=1){
int wn=qpow(type==1?g:gi,(mod-1)/(mid<<1));
for(int i=0;i<limit;i+=(mid<<1)){
int w=1;
for(int j=0;j<mid;j++,w=1ll*w*wn%mod){
int x=a[i+j],y=1ll*w*a[i+j+mid]%mod;
a[i+j]=Add(x,y);
a[i+j+mid]=Minus(x,y);
}
}
}
if(type==-1)
for(int i=0;i<limit;i++)
a[i]=1ll*a[i]*inv[limit]%mod;
return;
}
void init(int n){
int l=0;
limit=1;
while(limit<=n){
l++;
limit<<=1;
}
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
fac[0]=1;
for(int i=1;i<=mod-1;i++)
fac[i]=1ll*fac[i-1]*i%mod;
invfac[mod-1]=qpow(fac[mod-1],mod-2);
for(int i=mod-2;i>=0;i--)
invfac[i]=1ll*invfac[i+1]*(i+1)%mod;
pw[0]=1;
for(int i=1;i<mod-1;i++)
pw[i]=1ll*pw[i-1]*g%mod;
inv[1]=1;
for(int i=2;i<=mod-1;i++)
inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
return;
}
int main(){
int T;
scanf("%d",&T);
while(T--){
scanf("%d%d%d",&k,&n,&m);
for(int i=1;i<k;i++)
scanf("%s",a[i]);
init((k-1)*n);
for(int i=0;i<limit;i++)
f[i]=1,f2[i]=0;
for(int i=1;i<k;i++){
for(int j=0;j<=n;j++)
d[i][j]=1ll*(a[i][j]-='0')*invfac[j]%mod;
for(int j=n+1;j<limit;j++)
d[i][j]=0;
NTT(d[i],limit,1);
for(int j=0;j<limit;j++)
if(d[i][j])
f[j]=1ll*f[j]*d[i][j]%mod;
else
f2[j]++;
}
for(int i=0;i<limit;i++)
if(!f2[i])
ans[i]=f[i];
else ans[i]=0;
while(m--){
int x,y;
scanf("%d%d",&x,&y);
a[x][y]^=1;
for(int i=0;i<limit;i++)
if(d[x][i])
f[i]=1ll*f[i]*inv[d[x][i]]%mod;
else f2[i]--;
int s1=pw[((mod-1)/limit*y)%(mod-1)],s2=invfac[y];
if(!a[x][y])
s2=Minus(0,s2);
for(int i=0;i<limit;i++){
d[x][i]=Add(d[x][i],s2);
s2=1ll*s1*s2%mod;
}
for(int i=0;i<limit;i++){
if(d[x][i])
f[i]=1ll*f[i]*d[x][i]%mod;
else f2[i]++;
if(!f2[i])
ans[i]=Add(ans[i],f[i]);
}
}
NTT(ans,limit,-1);
Ans=0;
for(int i=1;i<limit;i++)
Ans=(Ans+1ll*ans[i]*fac[i])%mod;
printf("%d\n",Ans);
}
return 0;
}