题目大意:给你一个n*m的棋盘,你需要把一些格子染成黑色,使得有恰好k个黑色联通块,对998244353取模, n≤3,k,m≤5e4 n ≤ 3 , k , m ≤ 5 e 4 。
题解:考虑dp,dp[i][j][S]表示第i列,第i列的黑格子集合是S并且有j个联通块(注意当n=3的时候第一、三行同是黑格子的时候有是否联通两种情况),这样|S|<=9(事实上有几种状态相同可以去掉,即|S|<=7)。这样直接dp可以做到O(mk|S|),显然T的飞起。考虑优化,如果把f[i][S]看成一个多项式,其中x^j的系数是前i列最后一列状态S,恰有j个联通块的方案数。这样发现这个多项式的转移就是f[i-1][T]乘以x或者1或者1/x然后加起来。直接这么做没有优化,考虑插值,假设代入了一个x0,那么可以用矩阵乘法得出f[m][S](x0)。但直接这么做还不能够优化复杂度(除非你写一个多点插值),但其实取x0为一些单位根即可,最后NTT回来,这样复杂度就降下来了。
代码:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define mod 998244353
#define lint long long
#define N 1000010
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
int r[N];
struct matrix{
int n,m,v[15][15];
matrix(int _n=0,int _m=0)
{
n=_n,m=_m;
for(int i=1;i<=n;i++)
memset(v[i],0,sizeof(int)*(m+1));
}
inline int init(int _n)
{
n=m=_n;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
v[i][j]=(i==j);
return 0;
}
inline matrix operator=(const matrix &b)
{
n=b.n,m=b.m;
for(int i=1;i<=n;i++)
memcpy(v[i],b.v[i],sizeof(int)*(m+1));
return *this;
}
inline matrix operator*(const matrix &b)
{
matrix &a=*this,c(a.n,b.m);
for(int i=1;i<=c.n;i++)
for(int k=1;k<=a.m;k++)
for(int j=1;j<=c.m;j++)
(c.v[i][j]+=(lint)a.v[i][k]*b.v[k][j]%mod)%=mod;
return c;
}
inline matrix operator*=(const matrix &b)
{ return (*this)=(*this)*b; }
};
inline matrix fast_pow(matrix x,int k)
{
matrix ans(x.n);
for(ans.init(x.n);k;k>>=1,x*=x)
if(k&1) ans*=x;
return ans;
}
inline int fast_pow(int x,int k,int ans=1)
{
for(;k;k>>=1,x=(lint)x*x%mod)
if(k&1) ans=(lint)ans*x%mod;
return ans;
}
inline int NTT(int *a,int n,int sgn,int p=mod)
{
for(int i=0;i<n;i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=2,g=3;i<=n;i<<=1)
{
int wn=fast_pow(g,(sgn>0)?(p-1)/i:p-1-(p-1)/i);
for(int j=0,t=i>>1;j<n;j+=i)
for(int k=0,w=1;k<t;k++,w=(lint)w*wn%p)
{
int x=a[j+k],y=(lint)w*a[j+k+t]%p;
a[j+k]=(x+y)%p,a[j+k+t]=(x-y+p)%p;
}
}
return 0;
}
int m,c[N];
inline int calc1(int x)
{
matrix A(2,1),B(2,2);
A.v[1][1]=1,A.v[2][1]=B.v[2][1]=x;
B.v[1][1]=B.v[1][2]=B.v[2][2]=1;
A=fast_pow(B,m-1)*A;
return (A.v[1][1]+A.v[2][1])%mod;
}
inline int calc2(int x)
{
matrix A(4,1),B(4,4);
for(int i=0;i<4;i++)
for(int j=0;j<4;j++)
if(i&&!(i&j)) B.v[i+1][j+1]=x;
else B.v[i+1][j+1]=1;
A.v[1][1]=1,A.v[2][1]=A.v[3][1]=A.v[4][1]=x;
A=fast_pow(B,m-1)*A;int ans=0;
for(int i=1;i<=4;i++) (ans+=A.v[i][1])%=mod;
return ans;
}
inline int calc3(int x)
{
matrix A(9,1),B(9,9);
for(int i=0;i<8;i++)
for(int j=0;j<8;j++)
if(i&&!(i&j)) B.v[i+1][j+1]=x;
else B.v[i+1][j+1]=1;
for(int i=1;i<=9;i++) B.v[6][i]=0;
B.v[6][6]=1,B.v[6][8]=1;
int xinv=fast_pow(x,mod-2);
B.v[9][1]=B.v[9][3]=(lint)x*x%mod,
B.v[9][2]=B.v[9][4]=B.v[9][5]=B.v[9][7]=x,
B.v[9][6]=B.v[9][8]=0,B.v[9][9]=1;
B.v[1][9]=B.v[2][9]=B.v[4][9]=B.v[5][9]=B.v[7][9]=1,
B.v[3][9]=x,B.v[6][9]=0,B.v[8][9]=xinv,B.v[9][9]=1;
for(int i=2;i<=8;i++) A.v[i][1]=x;A.v[1][1]=1;
A.v[6][1]=0,A.v[9][1]=(lint)x*x%mod;
A=fast_pow(B,m-1)*A;int ans=0;
for(int i=1;i<=9;i++) (ans+=A.v[i][1])%=mod;
return ans;
}
int main()
{
// freopen("chess.in","r",stdin);
// freopen("chess.out","w",stdout);
int n,k;scanf("%d%d%d",&n,&m,&k);
if(n==1)
{
int L,wn;for(n=1,L=0;n<=(m-1)/2+1;n<<=1,L++);
for(int i=1;i<=n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
wn=fast_pow(3,(mod-1)/n);
for(int i=0,w=1;i<n;i++,w=(lint)w*wn%mod) c[i]=calc1(w);
NTT(c,n,-1);int ninv=fast_pow(n,mod-2);
for(int i=0;i<n;i++) c[i]=(lint)c[i]*ninv%mod;
return !printf("%d\n",c[k]);
}
if(n==2)
{
int L,wn;for(n=1,L=0;n<=m;n<<=1,L++);
for(int i=1;i<=n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
wn=fast_pow(3,(mod-1)/n);
for(int i=0,w=1;i<n;i++,w=(lint)w*wn%mod) c[i]=calc2(w);
NTT(c,n,-1);int ninv=fast_pow(n,mod-2);
for(int i=0;i<n;i++) c[i]=(lint)c[i]*ninv%mod;
return !printf("%d\n",c[k]);
}
if(n==3)
{
int L,wn;for(n=1,L=0;n<=(3*m-1)/2+1;n<<=1,L++);
for(int i=1;i<=n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
wn=fast_pow(3,(mod-1)/n);
for(int i=0,w=1;i<n;i++,w=(lint)w*wn%mod) c[i]=calc3(w);
NTT(c,n,-1);int ninv=fast_pow(n,mod-2);
for(int i=0;i<n;i++) c[i]=(lint)c[i]*ninv%mod;
return !printf("%d\n",c[k]);
}
return 0;
}