题目地址:http://acm.hdu.edu.cn/showproblem.php?pid=2243
题目意思:
给你n个字符串,然后给你一个长度L
问你在长度不超过L的所有字符串中(a~z)有多少个至少含有一个子串
意思很明确了,下面说解法,这题和POJ2778很类似,详见:http://blog.csdn.net/dr5459/article/details/8971626
那我们就建立一个自动机,通过和POJ2778类似的方式把跳转矩阵A求出来
然后求A^1+A^2+A^3+A^4+...+A^L求出来
但是这个太麻烦了,我们要引入一个公式:
通过这个方法就可以求出某个矩阵的连续指数和,然后减去E就可以了
那么我们就可以求L+1次方就可以,因为我们最后只是求的是矩阵的第一列,所以在答案中-1就行
那么这是不可能的,我们在求出26^1+26^2+...+26^n就可以了
一相减即为结果。因为要模2^64次方,我们用unsigned long long就可以
另外杭电上要用%I64u,不然会有各种错误。
代码:
#include<cstdio>
#include<cmath>
#include<cstring>
#include<queue>
#include<algorithm>
using namespace std;
#define ULL unsigned long long
const int maxnode = 6*6+5;
const int size=26;
struct AC
{
int ch[maxnode][size];
int f[maxnode];
bool val[maxnode];
int sz;
void init()
{
memset(ch[0],-1,sizeof(ch[0]));
sz=1;
val[0]=false;
}
int idx(char c)
{
return c-'a';
}
void insert(char *s)
{
int len = strlen(s);
int u=0;
for(int i=0;i<len;i++)
{
int c=idx(s[i]);
if(-1 == ch[u][c])
{
memset(ch[sz],-1,sizeof(ch[sz]));
val[sz]=false;
ch[u][c]=sz++;
}
u=ch[u][c];
}
val[u]=true;
}
void getfail()
{
queue<int> q;
for(int i=0;i<size;i++)
{
if(-1 != ch[0][i])
{
f[ch[0][i]] = 0;
q.push(ch[0][i]);
}
else
ch[0][i] = 0;
}
while(!q.empty())
{
int r=q.front();
q.pop();
//这里我们是要把trie变为一棵跳转的树,不是一棵匹配树
if(val[f[r]])
val[r]=true;
for(int i=0;i<size;i++)
{
int &v=ch[r][i];
if(-1 != v)
{
q.push(v);
f[v]=ch[f[r]][i];
}
else
v=ch[f[r]][i];
}
}
}
};
AC ac;
struct maxtrix
{
ULL m[32*2][32*2];
};
void maxtrixmul(maxtrix a,maxtrix b,maxtrix &ans,int n)
{
memset(ans.m,0,sizeof(ans.m));
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
for(int k=0;k<n;k++)
{
ans.m[i][j]+=a.m[i][k]*b.m[k][j];
}
}
}
}
void maxtrixpower(ULL n,maxtrix a,maxtrix &ans,int len)
{
if(n==1)
return;
n--;
while(n)
{
if(n&1)
{
maxtrixmul(a,ans,ans,len);
}
maxtrixmul(a,a,a,len);
n=n/2;
}
}
void build(maxtrix &ans,int n)
{
memset(ans.m,0,sizeof(ans.m));
for(int i=0;i<n;i++)
{
for(int j=0;j<size;j++)
{
if(!ac.val[i] && !ac.val[ac.ch[i][j]])
{
ans.m[i][ac.ch[i][j]]++;
}
}
}
}
ULL sum_26_n(ULL l)
{
maxtrix tmp;
tmp.m[0][0] = 26;
tmp.m[0][1] = tmp.m[1][1] = 1;
tmp.m[1][0] = 0;
maxtrixpower(l+1,tmp,tmp,2);
return tmp.m[0][1] - 1;
}
ULL sum_a_n(ULL l,maxtrix tiaozhuan)
{
maxtrix tmp;
memset(tmp.m,0,sizeof(tmp.m));
int n = ac.sz;
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
tmp.m[i][j] = tiaozhuan.m[i][j];
}
for(int i=0;i<2*n;i++)
{
for(int j=n;j<2*n;j++)
{
if(i<n && i==j-n)
tmp.m[i][j] = 1;
if(i>=n && i==j)
tmp.m[i][j] = 1;
}
}
maxtrixpower(l+1,tmp,tmp,ac.sz*2);
ULL ans=0;
for(int i=n;i<2*n;i++)
ans+=tmp.m[0][i];
return ans-1;
}
int main()
{
ULL m,l;
while(~scanf("%I64u%I64u",&m,&l))
{
ac.init();
char op[10];
for(int i=0;i<m;i++)
{
scanf("%s",op);
ac.insert(op);
}
ac.getfail();
ULL sum1 = sum_26_n(l);
maxtrix tiaozhuan;
build(tiaozhuan,ac.sz);
ULL sum2 = sum_a_n(l,tiaozhuan);
printf("%I64u\n",sum1-sum2);
}
return 0;
}