Description
兔子们在玩字符串的游戏。首先,它们拿出了一个字符串集合S,然后它们定义一个字
符串为“好”的,当且仅当它可以被分成非空的两段,其中每一段都是字符串集合S中某个字符串的前缀。
比如对于字符串集合{"abc","bca"},字符串"abb","abab"是“好”的("abb"="ab"+"b",abab="ab"+"ab"),而字符串“bc”不是“好”的。
兔子们想知道,一共有多少不同的“好”的字符串。
Input
第一行一个整数n,表示字符串集合中字符串的个数
接下来每行一个字符串
Output
一个整数,表示有多少不同的“好”的字符串
Sample Input
2
ab
ac
ab
ac
Sample Output
9
HINT
1<=n<=10000,每个字符串非空且长度不超过30,均为小写字母组成。
Source
考虑一个合法的串,它在AC自动机上匹配的路径是唯一的
假设dp[i][j][k]表示当前长度为i,在AC自动机的j结点,然后第一次失配的长度为k的方案数
合法答案必须dep[j] + k > i, 否则就会出现中间失配但是被忽略的情况
因为一个串如果继续失配下去那它能选择的长度就变短了
然后发现只需要记录i - k即可,dp[i][j]就够了
特判掉i - k = 0, 即fail[j] != root的情况
#include <bits/stdc++.h>
#define xx first
#define yy second
#define mp make_pair
#define pb push_back
#define fill( x, y ) memset( x, y, sizeof x )
#define copy( x, y ) memcpy( x, y, sizeof x )
using namespace std;
typedef long long LL;
typedef pair < int, int > pa;
const int MAXN = 300005;
int fail[MAXN], nxt[MAXN][26], dep[MAXN], id = 1, rt = 1, n, q[MAXN], ql, qr;
LL dp[35][MAXN], ans;
bool f[MAXN][26];
char ch[35];
inline void ins()
{
int cur = rt, len = strlen( ch + 1 );
for( int i = 1 ; i <= len ; i++ )
{
if( !nxt[ cur ][ ch[ i ] - 'a' ] ) dep[ nxt[ cur ][ ch[ i ] - 'a' ] = ++id ] = i, f[ cur ][ ch[ i ] - 'a' ] = 1;
cur = nxt[ cur ][ ch[ i ] - 'a' ];
}
}
inline void init()
{
q[ ++qr ] = rt;
while( ql ^ qr )
{
int x = q[ ++ql ];
for( int i = 0 ; i < 26 ; i++ )
if( nxt[ x ][ i ] ) fail[ q[ ++qr ] = nxt[ x ][ i ] ] = nxt[ fail[ x ] ][ i ];
else nxt[ x ][ i ] = nxt[ fail[ x ] ][ i ];
}
}
inline void solve()
{
for( int i = 2 ; i <= id ; i++ ) if( fail[ i ] ^ rt ) ans++;
for( int i = 1 ; i <= id ; i++ )
for( int j = 0 ; j < 26 ; j++ ) if( !f[ i ][ j ] && nxt[ i ][ j ] != rt )
dp[ 1 ][ nxt[ i ][ j ] ]++;
for( int i = 1 ; i < 35 ; i++ )
for( int j = 1 ; j <= id ; j++ ) if( dp[ i ][ j ] )
{
ans += dp[ i ][ j ];
for( int k = 0 ; k < 26 ; k++ )
if( dep[ nxt[ j ][ k ] ] > i )
dp[ i + 1 ][ nxt[ j ][ k ] ] += dp[ i ][ j ];
}
}
int main()
{
#ifdef wxh010910
freopen( "data.in", "r", stdin );
#endif
for( int i = 0 ; i < 26 ; i++ ) nxt[ 0 ][ i ] = 1;
scanf( "%d", &n );
for( int i = 1 ; i <= n ; i++ ) scanf( "%s", ch + 1 ), ins();
init();
solve();
cout << ans << endl;
return 0;
}