Description
N<=2500
Solution
听说暴力状压可以过?然而我常数不好只有90分
考虑普通的状压,F[i][s][j]表示当前填到第i行,第i行的状态为s,用了j个哲学♂家的方案数
我们把最后一维看做多项式,用x^j的系数表示答案
咦?模数是998244353哦,那我们是不是可以用NTT加速呢?
如果我们求出对于所有
wi
,答案的多项式的点值,我们就可以通过一次插值来还原出原多项式
求点值直接Dp就可以了,位移相当于乘上w的幂。但是这样会T
怎么办呢?发现这样Dp可以矩阵乘法加速,于是速度就起飞了~~
Code
#pragma GCC optimize(2)
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long ll;
const int N=2.5*1e3+5,mo=998244353;
int n,m,len,lg,ni,a[3][3],b[9],c[9][9],cnt[9],pw[4];
int t[N*6],W[N*6],ans[N*6],an[9];
struct matrix{
int a[9][9];
friend matrix operator *(matrix y,matrix z) {
matrix x;memset(x.a,0,sizeof(x.a));
fo(k,1,b[0])
fo(i,1,b[0])
fo(j,1,b[0])
(x.a[i][j]+=(ll)y.a[i][k]*z.a[k][j]%mo)%=mo;
return x;
}
}g,f;
bool ok(int x) {
if (a[1][0]&&(x&(x<<1))) return 0;
if (a[1][2]&&(x&(x>>1))) return 0;
return 1;
}
bool check(int x,int y) {
x=b[x],y=b[y];
if ((a[0][1]||a[2][1])&&(x&y)) return 0;
if ((a[0][0]||a[2][2])&&(x&(y<<1))) return 0;
if ((a[0][2]||a[2][0])&&(x&(y>>1))) return 0;
return 1;
}
int mi(int x,int y) {
int z=1;
for(;y;y/=2,x=(ll)x*x%mo)
if (y&1) z=(ll)z*x%mo;
return z;
}
void pwr(int y) {
for(y--;y;y/=2,f=f*f)
if (y&1) g=g*f;
}
void DFT(int *a,int flag) {
fo(i,0,len-1) {
int p=0;
for(int j=i,k=0;k<lg;k++,j/=2) p=(p<<1)+(j&1);
t[p]=a[i];
}
for(int m=2;m<=len;m*=2) {
int half=m/2,times=len/m;
for(int i=0;i<half;i++) {
int w=(flag>0)?W[i*times]:W[len-i*times];
for(int j=i;j<len;j+=m) {
int u=t[j],v=(ll)t[j+half]*w%mo;
t[j]=(u+v)%mo;
t[j+half]=(u-v+mo)%mo;
}
}
}
fo(i,0,len-1) a[i]=t[i];
if (flag==-1) fo(i,0,len-1) a[i]=(ll)a[i]*ni%mo;
}
void prepare() {
len=1,lg=0;
while (len<n*3) len*=2,lg++;
ni=mi(len,mo-2);
W[0]=1;W[1]=mi(3,(mo-1)/len);
fo(i,2,len) W[i]=(ll)W[i-1]*W[1]%mo;
}
int dp(int w) {
pw[0]=1;fo(i,1,3) pw[i]=(ll)pw[i-1]*w%mo;
memset(f.a,0,sizeof(f.a));
int res=0;
fo(i,1,b[0]) an[i]=pw[cnt[i]];
if (n==1) {
fo(i,1,b[0]) (res+=an[i])%=mo;
return res;
}
fo(j,1,b[0])
fo(k,1,c[j][0])
(f.a[j][c[j][k]]+=pw[cnt[c[j][k]]])%=mo;
memcpy(g.a,f.a,sizeof(f.a));
pwr(n-1);
fo(i,1,b[0])
fo(j,1,b[0])
(res+=(ll)an[i]*g.a[i][j]%mo)%=mo;
return res;
}
int main() {
scanf("%d%d",&n,&m);prepare();
fo(i,0,2) fo(j,0,2) scanf("%d",&a[i][j]);
fo(i,0,7) if (ok(i)) b[++b[0]]=i;
fo(i,1,b[0])
fo(j,0,2)
if (b[i]&(1<<j))
cnt[i]++;
fo(i,1,b[0])
fo(j,1,b[0])
if (check(i,j))
c[i][++c[i][0]]=j;
prepare();
fo(i,0,len-1) ans[i]=dp(W[i]);
DFT(ans,-1);
printf("%d\n",ans[m]);
}