#include <bits/stdc++.h>
using namespace std;
// Number of table slots that are always available; longer patterns make the
// table grow on demand instead of overrunning a fixed buffer.
const int maxn = 100;

// Prefix table produced by initnext() and read by kmp().
// _next[i] is the index of the last character of the longest proper prefix of
// the pattern that is also a suffix of pattern[0..i], or -1 when no such
// prefix exists.
std::vector<int> _next(maxn);

// Fills _next for the pattern `str`.
//
// The table always keeps at least `maxn` slots and is enlarged whenever the
// pattern is longer, so patterns of any length are handled safely. Slots at
// index >= str.size() keep whatever value they held before the call.
// For an empty pattern only _next[0] is written (set to -1).
void initnext(const std::string& str) {
    const int len = static_cast<int>(str.size());
    if (static_cast<int>(_next.size()) < len) {
        _next.resize(len);
    }

    _next[0] = -1;
    for (int i = 1; i < len; ++i) {
        // Start from the border of the previous prefix and shrink it until it
        // can be extended by str[i] or no border is left (j == -1).
        int j = _next[i - 1];
        while (j >= 0 && str[j + 1] != str[i]) {
            j = _next[j];
        }
        // j + 1 is always a valid index here because j >= -1.
        _next[i] = (str[j + 1] == str[i]) ? j + 1 : -1;
    }
}
int kmp ( string str, string ptr) {
int slen= str. size ( ) , s= 0 ;
int plen= ptr. size ( ) , p= 0 ;
initnext ( ptr) ;
while ( s< slen&& p< plen) {
if ( str[ s] == ptr[ p] ) s++ , p++ ;
else {
if ( p== 0 ) s++ ;
else p= _next[ p- 1 ] + 1 ;
}
}
return p== plen? s- plen: - 1 ;
}
// Demo driver: searches one fixed text for two fixed patterns and prints the
// result of each kmp() call on its own line (the first pattern occurs in the
// text, the second does not).
int main() {
    const std::string text = "aaassdewaaaaswed";
    const std::string needles[2] = {"aaaas", "wef"};
    for (int k = 0; k < 2; ++k) {
        std::cout << kmp(text, needles[k]) << "\n";
    }
}
输出
8
-1
版本二
#include <bits/stdc++.h>
using namespace std;
// Capacity of every fixed-size buffer below.
const int maxn = 1003;

// Input buffers filled by main(): the text to scan and the pattern to find.
char str[maxn];
char pattern[maxn];

// Failure table written by getFail() and read by kmp().
int fail[maxn];

// Count of non-overlapping matches; incremented by kmp(), reset by main().
int cnt;
// Builds the failure table for the pattern `p` of length `plen` into the
// global `fail` array.
//
// After the call, fail[k] (for 0 <= k <= plen) is the length of the longest
// proper prefix of p[0..k-1] that is also a suffix of it; fail[0] and fail[1]
// are always 0. Entries fail[0..plen] are written, so `fail` must have room
// for plen + 1 values.
void getFail(char* p, int plen) {
    fail[0] = fail[1] = 0;
    for (int pos = 1; pos < plen; ++pos) {
        // Shrink the current border until it can be extended by p[pos] or it
        // becomes empty.
        int border = fail[pos];
        while (border != 0 && p[pos] != p[border]) {
            border = fail[border];
        }
        if (p[pos] == p[border]) {
            fail[pos + 1] = border + 1;
        } else {
            fail[pos + 1] = 0;
        }
    }
}
// Reports every occurrence of the NUL-terminated pattern `p` in the
// NUL-terminated text `s`.
//
// For each occurrence (overlapping ones included) the 0-based start index is
// printed on its own line. In addition, the global `cnt` is incremented once
// for every occurrence that does not overlap the previously counted one.
//
// Returns the number of non-overlapping occurrences counted during this call,
// i.e. the amount by which `cnt` was increased. The function previously had no
// return statement even though it is declared to return int.
// Side effect: calls getFail(), which overwrites the global `fail` table.
int kmp(char* s, char* p) {
    const int slen = static_cast<int>(strlen(s));
    const int plen = static_cast<int>(strlen(p));
    getFail(p, plen);

    int counted = 0;  // non-overlapping matches found in this call
    int last = -1;    // text index where the last counted match ended
    int j = 0;        // number of pattern characters currently matched
    for (int i = 0; i < slen; i++) {
        // After a full match j == plen, so p[j] is the terminating NUL; it
        // never equals s[i], which forces the fallback to fail[plen].
        while (j && s[i] != p[j]) j = fail[j];
        if (s[i] == p[j]) j++;
        if (j == plen) {
            printf("%d\n", i + 1 - plen);
            // Count the match only if it starts after the last counted one ended.
            if (i - last >= plen) {
                cnt++;
                counted++;
                last = i;
            }
        }
    }
    return counted;
}
// Reads whitespace-separated (text, pattern) pairs from standard input until
// no complete pair is left. For each pair, kmp() prints the start index of
// every occurrence, then the number of non-overlapping occurrences is printed.
int main() {
    // The field width 1002 is maxn - 1: each buffer holds maxn (1003) bytes,
    // so this leaves room for the terminating NUL and prevents an overflow on
    // over-long tokens. Keep the width in sync with maxn.
    // Requiring exactly 2 conversions also stops the loop when only the text
    // could be read, instead of searching with a stale pattern.
    while (scanf("%1002s%1002s", str, pattern) == 2) {
        cnt = 0;
        kmp(str, pattern);
        printf("%d\n", cnt);
    }
    return 0;
}