题目链接
题意
给你 n n n 个模式串,求长度不超过 m m m 由小写字母组成且至少出现其中一个模式串的方案数。
思路
可以转化题意求所有可能字符串-不出现任意一个模式串的方案数。
那么可以建ac自动机,ac自动机走不到各个模式串匹配终点即不出现模式串。
相当于对ac自动机的nxt数组形成的图求由起点出发走不超过 m m m 步的方案数,这个问题可以用图论知识加矩阵快速幂解决。
注意两个字符串 abc
和b
,b
节点到不了ab
的b节点也无法到达,即倒着建fail树上的儿子节点应该全部到不了。
代码末放了组对拍找到的bug数据
代码
#include <stdio.h>
#include <queue>
#include <string.h>
using namespace std;
#define ll long long
#define llu unsigned long long
const ll maxn = 55;
struct mat {
llu m[maxn][maxn];
};
mat operator *(mat x, mat y) {
mat ret;
for(int i = 0; i < maxn; i++)
{
for(int j = 0; j < maxn; j++)
{
ret.m[i][j] = 0;
for(int k = 0; k < maxn; k++)
{
ret.m[i][j] = (x.m[i][k] * y.m[k][j] + ret.m[i][j]);
}
}
}
return ret;
}
mat pow_mat(mat a, int fuck) {
mat ret;
memset(ret.m,0,sizeof(ret.m));
for(int i = 0; i < maxn; i++) ret.m[i][i] = 1;
while(fuck) {
if(fuck&1) ret = ret*a;
a = a*a;
fuck >>= 1;
}
return ret;
}
const int N = 6*5+5;
const int M = 26;
struct ACAM {
int nxt[N][M], fail[N], val[N], n, book[N*M];
void init() {
memset(nxt,0,sizeof(nxt));
memset(fail,0,sizeof(fail));
memset(val,0,sizeof(val));
memset(book,0,sizeof(book));
n = 0;
}
void add(char *s) {
int len = strlen(s), now = 0;
for(int i = 0; i < len; ++i) {
int tmp = s[i]-'a';
if(!nxt[now][tmp]) nxt[now][tmp] = ++n;
now = nxt[now][tmp];
}
++val[now];
}
void getfail() {
queue<int> q;
for(int i = 0; i < 26; ++i) if(nxt[0][i]) fail[nxt[0][i]] = 0, q.push(nxt[0][i]);
while(!q.empty()) {
int u = q.front(); q.pop();
for(int i = 0; i < 26; ++i) {
if(nxt[u][i]) fail[nxt[u][i]] = nxt[fail[u]][i], q.push(nxt[u][i]);
else nxt[u][i] = nxt[fail[u]][i];
val[nxt[u][i]] += val[nxt[fail[u]][i]]; // 标记向下传递
}
}
}
}ac;
char s[10];
mat a, b, x, y;
void dfs(int u) {
ac.book[u] = 1;
for(int i = 0; i < 26; ++i) {
int v = ac.nxt[u][i];
if(ac.val[v]) continue;
++x.m[u][v+1+ac.n];
++y.m[u+1+ac.n][v+1+ac.n];
if(!ac.book[v]) dfs(v);
}
}
int main() {
int n, m;
while(~scanf("%d%d",&n,&m)) {
ac.init();
for(int i = 0; i < n; ++i) scanf("%s",s), ac.add(s);
ac.getfail();
memset(a.m,0,sizeof(a.m));
memset(b.m,0,sizeof(b.m));
memset(x.m,0,sizeof(x.m));
memset(y.m,0,sizeof(y.m));
a.m[0][1] = 26;
b.m[0][0] = 1; b.m[1][0] = 1; b.m[1][1] = 26;
for(int i = 0; i <= ac.n; ++i) y.m[i][i] = y.m[i+ac.n+1][i] = 1;
dfs(0);
x = x*pow_mat(y,m);
llu sum = 0;
for(int i = 0; i <= ac.n; ++i) sum += x.m[0][i];
printf("%llu\n",(a*pow_mat(b,m)).m[0][0]-sum);
}
return 0;
}
/*
5 3
qdjv
gps
h
bxzq
j
*/