Description
兔子们在玩字符串的游戏。首先,它们拿出了一个字符串集合S,然后它们定义一个字
符串为“好”的,当且仅当它可以被分成非空的两段,其中每一段都是字符串集合S中某个字符串的前缀。
比如对于字符串集合{“abc”,”bca”},字符串”abb”,”abab”是“好”的(”abb”=”ab”+”b”,abab=”ab”+”ab”),而字符串“bc”不是“好”的。
兔子们想知道,一共有多少不同的“好”的字符串。
Input
第一行一个整数n,表示字符串集合中字符串的个数
接下来每行一个字符串
Output
一个整数,表示有多少不同的“好”的字符串
Sample Input
2
ab
ac
Sample Output
9
HINT
1<=n<=10000,每个字符串非空且长度不超过30,均为小写字母组成。
题意
给出一个字符串集合,问有几个字符串可以分成非空的两部分,使两部分都是字符串集合中的某个字符串的前缀.
做法
首先我们可以先对字符串集合中的所有字符建一棵字典树,这样每一个节点都可以表示一个前缀,那么在不算重复的情况下,字典树节点数(不算根节点)的平方即为合法字符串的数量.
下面考虑重复,如果一个字符串被算了多次,那么就一定有多种方法将它分为两个前缀,我们用每个切割点来代表每种方法,这样两个切割点a,b之间的这一段中心串就既是一个前缀的后缀,也同时是另一个前缀的前缀,也可以表示为root->b,a->b,b->endl,a->endl都是字符串集合中的前缀。
我们可以把字典树改为AC自动机,并且枚举a->endl这一段,那么它的fail就可以表示这一段字符串去掉最短的前缀后仍然是原来的前缀,因为a,b是相邻的两个分割点,因而在a,b之间没有合法的分割点,所以这一段字符串也可以表示b->endl,那么我们可以计算出所有以a->b为后缀的合法前缀的数量(可以预处理),这个值就是以a->b为中心串且后缀为a->endl对答案的贡献,对于一个分割点为a,b,c,d,e…….的字符串,如果按上述方法减去a,b,c,d,e…..(除了最后一个切割点)到末尾的这段字符串对答案的贡献,就可以使该字符串对答案的贡献仅为1,从而达到去重的目的。
在实际操作中,我们只需要枚举a->endl这一段(也就是所有fail不为根节点的节点),它的fail为b->endl这一段,那么沿着a的fa向上找len[b->endl]次后就表示a->b,减去以它为后缀的前缀数量,用字典树大小的平方减去所有贡献后即为最终答案。
代码
#include<iostream>
#include<cstdio>
#include<queue>
#define ll long long
#define N 10010
#define M 300100
using namespace std;
ll n,tt,ans,cnt[2][N];
string str;
struct Node
{
ll son[30],fail,cnt,len,fa;
}node[M];
queue<ll>que;
inline void in()
{
ll i,t,now=0,u;
for(i=0,t=str.size();i<t;i++)
{
u=str[i]-'a';
if(!node[now].son[u])
{
node[now].son[u]=++tt;
node[tt].len=node[now].len+1;
node[tt].fa=now;
}
now=node[now].son[u];
}
}
inline void build()
{
ll i,now,t,q,k;
for(i=0;i<26;i++)
{
if(node[0].son[i])
{
node[node[0].son[i]].fail=0;
que.push(node[0].son[i]);
}
}
for(;!que.empty();)
{
q=que.front();
que.pop();
for(i=0;i<26;i++)
{
if(!node[q].son[i]) continue;
for(k=node[q].fail;k&&!node[k].son[i];k=node[k].fail);
node[node[q].son[i]].fail=(node[k].son[i])?node[k].son[i]:0;
for(k=node[node[q].son[i]].fail;k;k=node[k].fail) node[k].cnt++;
que.push(node[q].son[i]);
}
}
}
inline int up(int u,int v)
{
for(;v--;)
{
u=node[u].fa;
}
return u;
}
int main()
{
ios::sync_with_stdio(0);
ll i,j,k,l1,l2;
cin>>n;
for(i=1;i<=n;i++)
{
cin>>str;
in();
}
build();
ans=tt*tt;
for(i=1;i<=tt;i++)
{
if(node[i].fail)
ans-=node[up(i,node[node[i].fail].len)].cnt;
}
cout<<ans;
}