题意:给出 k 个字符,求用这 k 个字符组成的长度为 m 的字符串中不包含子串 s 的方案数。
分析:解决这个问题之前,得知道一些东西
①首先得知道矩阵可以解决这类经典问题:
给定一个有向图,问从A点恰好走k步(允许重复经过边)到达B点的方案数mod p的值
把给定的图转为邻接矩阵,即A(i,j)=1当且仅当存在一条边i->j。令C=A*A,那么C(i,j)=ΣA(i,k)*A(k,j),实际上就等于从点i到点j恰好经过2条边的路径数(枚举k为中转点)。类似地,C*A的第i行第j列就表示从i到j经过3条边的路径数。同理,如果要求经过k步的路径数,我们只需要二分求出A^k即可。
(以上的描述来自:https://blog.csdn.net/kopyh/article/details/51179617)
这样我们可以想清楚,若我们定义了一个合理的初始矩阵,并对这个矩阵做 m 次快速幂计算,则可以得到想要的结果,那么怎么构建成了最大的难点:
我们考虑一个初始矩阵 c, c[i][j] 表示当前构建的字符串末尾为子串 s 前 i 位,并且添加下一位后转移到字符串末尾变成子串 s 前 j 位的方案数。想一下:
①若当前添加的字符是子串 s[i] 的下一位,那么 j 是不是就等于 i+1;
②若当前添加的字符不是子串 s[i] 的下一位,那么就利用KMP的 next 数组往前继续匹配,直到匹配到一个位置pos,使得当前添加的字符是pos 的下一位 这样 i 就转移到了pos+1;
③或者匹配不到 pos ,即 当前添加字符不是子串 s 的任何前缀的下一位,那么 i 就转移到了 0;
这样初始矩阵就构建完成了,对它做 m 次幂运算,然后计算 即是所求。
(可能有人会疑惑为什么这样枚举计算出来的字符串不会包含子串 s ,虽然我们设置初始矩阵时 c[n-1][n] 可能不为零 ,但是我们计算的时候只计算 c[0~n-1][0~n-1],这样便不会考虑到任何 c[i][n] 和 c[n][i] 0<=i<=n ,即我们计算的时候已经省略了这种情况发生的可能!)
代码:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 55;
int m,n;
char str[N],ch[N];
int nxt[N];
struct Matrix
{
unsigned mat[N][N];
Matrix() { memset(mat,0,sizeof(mat)); }
};
Matrix operator * (Matrix a,Matrix b)
{
Matrix c=Matrix();
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
c.mat[i][j]+=a.mat[i][k]*b.mat[k][j];
return c;
}
Matrix qpow(Matrix a,int x)
{
Matrix c=Matrix();
for(int i=0;i<n;i++) c.mat[i][i]=1;
while(x)
{
if(x&1) c=c*a;
a=a*a;
x>>=1;
}
return c;
}
void GetNext(char s[],int len,int nxt[])
{
nxt[0]=-1;
int j=0,k=-1;
while(j<len)
{
if(k==-1||s[j]==s[k])
{
++j,++k;
nxt[j]=k;
}
else
{
k=nxt[k];
}
}
}
int main()
{
int T;
scanf("%d", &T);
for(int cas=1;cas<=T;cas++)
{
scanf("%d%s%s",&m,ch,str);
n=strlen(str);
GetNext(str,n,nxt);
Matrix ans=Matrix();
for(int i=0;i<strlen(ch);i++) //按照给定字符集枚举要添加的字符
for(int j=0;j<n;j++)
{
int pos=j;
if(str[pos]==ch[i]) //一:当前添加的字符是 i 的下一位
{
ans.mat[pos][pos+1]++;
continue;
}
while(pos&&str[pos]!=ch[i]) pos=nxt[pos]; //二:第一种条件不满足,往前匹配
if(str[pos]==ch[i]) pos++; //二:判断是否往前匹配成功
ans.mat[j][pos]++;
}
ans=qpow(ans,m);
unsigned tot=0; //取模是 2^32,用 unsigned 自然溢出就可以了
for(int i=0;i<n;i++) tot+=ans.mat[0][i];
printf("Case %d: %u\n",cas,tot);
}
}