题目概述
有一个n位的准考证号码(每位都是0~9),给出一个m位的不吉利数字,求不包含不吉利数字的准考证号码(不吉利数字不是这个准考证号码的子串)的个数。
解题报告
像这种问题基本做法要么是递推,要么是组合数学。思考后,发现无论正着思考还是反着思考都很难直接推出公式,于是考虑递推:
f[i][j]表示前i位,匹配到不吉利数字前j位且不出现不吉利数字的方案数。
然后来推递推式,是这样的吗?
j>0:f[i][j]=f[i-1][j-1]+f[i-1][j]*9
j==0:f[i][j]=Σf[i-1][0~m-1]*9
其实是错的,因为当第i位选不同数字时,f[i][j]还可以由其他f[i-1][k]推过来,其中k满足k>=j且k-j+1~k与1~j相同(比如k==j时)。然后我们会惊奇的发现这就是KMP的失配过程,所以我们可以用KMP来处理出递推式。
不过这个递推无论时间还是空间都是不行的,所以我们想到了矩阵乘法,设矩阵乘法的变换矩阵是T,那么每次递推形如:
[fi,0fi,1⋯fi,m−1fi,m]∗T=[fi+1,0fi+1,1⋯fi+1,m−1fi+1,m]
如何构造T是关键,根据这个递推,我们可以得知T是0~m*0~m的一个矩阵,并且T[x][y]表示
fi,x
对
fi+1,y
造成的影响。那么T的构造实际上就是上面所说的用KMP处理出递推式:先枚举i表示上次的第i位,然后枚举这次选的数now,用KMP的失配函数处理出最靠近的匹配点j,则i必定是满足j的一个状态,累加T[i][j]。
最后利用快速幂求出f[n],Σf[n][0~m-1]就是答案。
示例程序
#include<cstdio>
#include<cstring>
using namespace std;
const int maxm=20;
//=================================================
int n,m,MOD,sum,s[maxm+5],fa[maxm+5];
//=================================================
struct Matrix
{
int r,c,num[maxm+5][maxm+5];
void clear(int R,int C) {r=R;c=C;memset(num,0,sizeof(num));}
};
Matrix ans,T,c;
Matrix operator * (const Matrix &a,const Matrix &b) //矩阵乘法
{
c.clear(a.r,b.c);
for (int i=0;i<a.r;i++)
for (int j=0;j<b.c;j++)
for (int k=0;k<a.c;k++)
c.num[i][j]=(c.num[i][j]+a.num[i][k]*b.num[k][j])%MOD;
return c;
}
//=================================================
char getrch() {char ch=getchar();while ('9'<ch||ch<'0') ch=getchar();return ch;}
void make_fa(int *s)
{
fa[0]=fa[1]=0;
for (int i=2,j=0;i<=m;i++)
{
while (j&&s[j+1]!=s[i]) j=fa[j];
if (s[j+1]==s[i]) j++;
fa[i]=j;
}
}
Matrix power(Matrix w,int b) //矩阵快速幂
{
Matrix s;s.clear(m+1,m+1);
for (int i=0;i<=m;i++) s.num[i][i]=1;
while (b)
{
if (b&1) s=s*w;b>>=1;
if (b) w=w*w;
}
return s;
}
int main()
{
freopen("program.in","r",stdin);
freopen("program.out","w",stdout);
scanf("%d%d%d",&n,&m,&MOD);
for (int i=1;i<=m;i++) s[i]=getrch()-48;
make_fa(s);T.clear(m+1,m+1);
for (int i=0;i<=m-1;i++)
for (int now=0;now<=9;now++)
{
int j=i;
while (j&&s[j+1]!=now) j=fa[j];
if (s[j+1]==now) j++;
if (j!=m) T.num[i][j]=(T.num[i][j]+1)%MOD; //j==m时不满足不出现不吉利数字
}
ans.clear(1,m+1);ans.num[0][0]=1;ans=ans*power(T,n); //初始时f[0][0]=1
for (int i=0;i<=m-1;i++) sum=(sum+ans.num[0][i])%MOD;
printf("%d\n",sum);
return 0;
}