做这个题之前最好做一下POJ 2278(题解)
在POJ2278的基础上,
最终的答案就是26^1+26^2+......+26^L减去A^1+A^2+....+A^L
我们构造这么一个矩阵
|A , 1|
|0 , 1| 它 的n次方等于
|A^n , 1+A^1+A^2+....+A^(n-1)|
|0 , 1|
如果A是一个矩阵 那么1 和 0 也分别是[1 1 1... 1]T 和 [0 0 0.... 0]
那么结果是
|A^n , (1+A^1+A^2+....+A^(n-1))*[1 1 1 .. 1]T |
|[0 0 0 ...0] , 1 |
因为我们要求∑a[0][i],所以结果的第一行就是我们需要的,因为多加了一个一,减去就可以了。
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#define LL long long
#define ULL unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
int n;char str[10];LL k;
struct matrix{ULL a[33][33];}X;
struct Trie
{
int next[33][26],fail[33],end[33];
int root,L;
int newnode()
{
for(int i = 0;i < 26;i++)
next[L][i] = -1;
end[L++] = 0;
return L-1;
}
void init(){L = 0;root = newnode();}
void insert(char buf[])
{
int len = strlen(buf);
int now = root;
for(int i = 0;i < len;i++)
{
if(next[now][buf[i]-'a'] == -1)
next[now][buf[i]-'a'] = newnode();
now = next[now][buf[i]-'a'];
}
end[now]=1;
}
void build()
{
queue<int>Q;
fail[root] = root;
for(int i = 0;i < 26;i++)
if(next[root][i] == -1)
next[root][i] = root;
else
{
fail[next[root][i]] = root;
Q.push(next[root][i]);
}
while( !Q.empty() )
{
int now = Q.front();
Q.pop();
if(end[fail[now]])
end[now]=1;
for(int i = 0;i < 26;i++)
if(next[now][i] == -1)
next[now][i] = next[fail[now]][i];
else
{
fail[next[now][i]]=next[fail[now]][i];
Q.push(next[now][i]);
}
}
}
}ac;
matrix multi(matrix A,matrix B,int n)
{
matrix C;
memset(C.a,0,sizeof(C.a));
int i,j,k;
for(i=0;i<n;i++)
for(j=0;j<n;j++)
{
for(k=0;k<n;k++)
C.a[i][j] += A.a[i][k]*B.a[k][j];
}
return C;
}
matrix quickly(matrix A,LL k,int n)
{
matrix ans;
memset(ans.a,0,sizeof(ans.a));
for(int i=0;i<n;i++)
ans.a[i][i]=1;
while(k)
{
if(k & 1) ans = multi(A,ans,n);
A = multi(A,A,n);
k>>=1;
}
return ans;
}
matrix getmatrix(matrix A)
{
memset(A.a,0,sizeof(A));
for(int i=0;i<ac.L;i++)
for(int j=0;j<26;j++)
if(!ac.end[ac.next[i][j]])
A.a[i][ac.next[i][j]]++;
for(int i=0;i<=ac.L;i++)
A.a[i][ac.L]=1;
return A;
}
int main()
{
int i;
while(~scanf("%d%I64d",&n,&k))
{
ac.init();
for(i=1;i<=n;i++)
{
scanf("%s",str);
ac.insert(str);
}
ac.build();
X = getmatrix(X);
X = quickly(X,k,ac.L+1);
ULL ans1 = 0;
for(i=0;i<=ac.L;i++) ans1 += X.a[0][i];
ULL ans2 = 0;
X.a[0][0]=26;X.a[0][1]=X.a[1][1]=1;X.a[1][0]=0;
X = quickly(X,k+1,2);
ans2 = X.a[0][1];
ans2 -= ans1;
printf("%I64u\n",ans2);
}
return 0;
}