Censored!
这个题目写了一天多的时间,开始的时候,虽然对主要的思路清楚了,但是有些细节问题没有想清楚。在解DP的时候,开始写的时候没有对状态转化搞的很清楚,所以中间出现错误然后又调试了很久,这就浪费了很长的时间。以后还是想清楚细节之后再去敲代码,得训练一下写代码的思维,如何思考更有利于写代码,而不仅仅是注重于问题的解法是什么。
说下这个题目我个人的想法:先建立trie树,然后构建fail指针,在就是对于状态的转化问题。我们知道,对于给定的字符集,那么如果给定字符串的长度,那么我们可以将对应的可能构成的字符串按结尾字符分为如果干类。而在这里,我们也是按照相似的分字符串的方式。我们可以考虑给定的字符串,按其匹配情况来分类,那么如果构成的字典树中有m个节点,那么很明显,我们可以按照这个将对应的字符串分入到这m个集合中,也就是说一个字符串必定处于这m个集合中的某一个,这样我们就可以将所有的字符串划分为m个类,然后总数就是这m个类中的个数之和。那么我们就可以写出DP的状态表示了,dp[i][j],表示长度为i,终止于状态j(节点编号为j)的字符串的个数。
然后是DP的状态转移方程:如果我知道长度为i-1的字符串此时处于状态k,那么当我从给定的字符集中取出一个字符加入到字符串尾的时候,它可能跳转到那些状态,可能回到初始状态,也可能跳转到下一个状态。这时我们就需要AC自动机来分析了,我们考虑如果添加某个字符c,当前状态是k,那么如果:c为状态k可以识别的(这里的意思是:对于构建的trie树有后继节点),那么就进入k的下一个状态,否则的话,我们就只能考虑k状态的失败指针指向的位置了,看它能否识别该字符,如果不能的话,继续找fail指针,直到找到根节点,对于k不能识别的情况,我们可以如下考虑:因为到达状态k后会有一部分序列与fail[k]是相同的,那么我们可以识别到下一个状态fail[k]的后继节点。这样的话,我们就可以写出了状态转移方程:dp[i][j] = sum(dp[i-1][k]),k是所有可能的状态集合,如果k能识别字符i。
但是这里要求了,我们不能将那些危险序列包含进来,这样的话,我们可以对dp[i][j]进行一个限制,要求dp[i][j]不包含危险序列,或者说是未经过危险序列。处理的时候,我们不能走到叶节点,同样的,对于如果有相同的公共序列的单词间,我们都应该要求这写特殊的节点不能经过,我们只需要做上标记即可。然后当状态转移的时候,我们需要判断改状态是不是合法即可。
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std ;
#define maxn 110
const int mod = 10000 ;
int next[maxn][55] ;
int fail[maxn] ;
int flag[maxn] ;
int dp[55][maxn][30] ;
char alpa[55] ;
int ncount ;
int n , m , p ;
void Add(int numa[] , int numb[]){
for(int i = 0 ; i < 30 ; i ++){
numa[i] += numb[i] ;
numa[i + 1] += numa[i]/mod ;
numa[i] = numa[i]%mod ;
}
}
void print(int ans[]){
bool flag(0) ;
for(int i = 29 ; i >=0 ; i --){
if(flag){
printf("%04d" , ans[i]) ;
}
else if(ans[i]){
printf("%d" , ans[i]) ;
flag = 1 ;
}
}
if(!flag)
printf("0") ;
printf("\n") ;
}
int hash(char c){
int i(0) ;
while(alpa[i]){
if(alpa[i]==c)
return i ;
i ++ ;
}
return -1 ;
}
void init(){
memset(next , 0 , sizeof(next)) ;
memset(flag , 0 , sizeof(flag)) ;
memset(fail , 0 , sizeof(fail)) ;
memset(dp , 0 , sizeof(dp)) ;
ncount = 0 ;
}
void insert(char * word){
char *p = word ;
int b ;
int c(0) ;
while(*p){
b = hash(*p) ;
if(!next[c][b]){
next[c][b] = ++ncount ;
}
c = next[c][b] ;
p ++ ;
}
flag[c] = 1 ;
}
bool read(){
if(scanf("%d%d%d\n" , &n , &m , &p)==EOF)
return 0 ;
gets(alpa) ;
init() ;
char str[55] ;
for(int k = 1 ; k <= p ; k ++){
gets(str) ;
insert(str) ;
}
return 1 ;
}
void build_ac(){
queue<int> Q ;
int cur = 0 ;
Q.push(cur) ;
fail[cur] = 0 ;
int child , k , tmp ;
while(!Q.empty()){
cur = Q.front() ;
Q.pop() ;
for( k = 0 ; k < n ; k ++){
child = next[cur][k] ;
if(child){
Q.push(child) ;
if(cur == 0)
fail[child] = 0 ;
else{
tmp = fail[cur] ;
while(tmp && !next[tmp][k])
tmp = fail[tmp] ;
if(next[tmp][k])
fail[child] = next[tmp][k] ;
else
fail[child] = 0 ;
}
if(flag[ fail[child] ])
flag[ child ] = 1 ;
}
else{
next[cur][k] = next[fail[cur]][k] ;
}
}
}
}
void solve(){
build_ac() ;
dp[0][0][0] = 1 ;
for(int i = 1 ; i <= m ; i ++){
for(int k = 0 ; k <= ncount ; k ++){
if(flag[k])
continue ;
for(int j = 0 ; j < n ; j ++){
int t = next[k][j] ;
if(flag[t])
continue ;
Add(dp[i][t] , dp[i-1][k]) ;
}
}
}
int ans[30] ;
memset(ans , 0 , sizeof(ans)) ;
for(int i = 0 ; i <= ncount ; i ++){
if(flag[i])
continue ;
Add(ans , dp[m][i]) ;
}
print(ans) ;
}
int main(){
while( read() ){
solve();
}
return 0 ;
}