串
串
(prefix.cpp/c/pas)
题目描述
兔子们在玩字符串的游戏。首先,它们拿出了一个字符串集合
S
S
S,然后它们定义一个字符串为“好”的,当且仅当它可以被分成非空的两段,其中每一段都是字符串集合
S
S
S 中某个字符串的前缀。比如对于字符串集
"
a
b
c
"
,
"
b
c
a
"
{"abc", "bca"}
"abc","bca",字符串
"
a
b
b
"
,
"
a
b
a
b
"
"abb","abab"
"abb","abab"是“好”的
(
"
a
b
b
"
=
"
a
b
"
+
"
b
"
,
a
b
a
b
=
"
a
b
"
+
"
a
b
"
)
("abb" ="ab"+"b", abab = "ab" + "ab")
("abb"="ab"+"b",abab="ab"+"ab"),而字符串
“
b
c
”
“bc”
“bc”不是“好”的。兔子们想知道,一共有多少不同的“好”的字符串。
输入格式
第一行一个整数
n
n
n,表示字符串集合中字符串的个数
接下来每行一个字符串
输出格式
一个整数,表示有多少不同的“好”的字符串
样例输入
2
ab
ac
样例输出
9
数据范围与约定
对于 20%的数据,
1
≤
n
≤
200
1 \le n \le 200
1≤n≤200
对于 50%的数据,
1
≤
n
≤
2000
1 \le n \le 2000
1≤n≤2000
对于 100%的数据,
1
≤
n
≤
10000
1 \le n \le 10000
1≤n≤10000,每个字符串非空且长度不超过
30
30
30,均为小写字母
组成。
思路:
对所有不同前缀排列组合,一共能组成
n
u
m
2
num^2
num2个字符串,
n
u
m
num
num为
T
i
r
e
Tire
Tire树的节点个数。
然而发现有重复的字符串:
"
a
b
a
"
+
"
a
b
"
=
"
a
b
a
a
b
"
"aba"+"ab"="abaab"
"aba"+"ab"="abaab"
"
a
b
a
a
"
+
"
b
"
=
"
a
b
a
a
b
"
"abaa"+"b"="abaab"
"abaa"+"b"="abaab"
考虑去重。
对于两个相同字符串,如下组成:
发现必有两个不同前缀(绿,橙),一个的后缀与另一个的前缀相同。一个字符串的多余贡献为前缀能与它的不同后缀匹配上的字符串的个数。
A C AC AC自动机的 f a i l fail fail数组为不同字符串匹配前后缀,那么一个字符串的一个后缀多余贡献即为此后缀开始节点在 f a i l fail fail树上的子树大小。
代码:
#include<bits/stdc++.h>
using namespace std;
#define in Read()
#define LL long long
inline int in{
int s=0,f=1;char x;
for(x=getchar();x<'0'||x>'9';x=getchar()) if(x=='-') f=-1;
for( ;x>='0'&&x<='9';x=getchar()) s=(s<<1)+(s<<3)+(x&15);
return f==1?s:-s;
}
const int A=1e4+5;
int n;
char a[40];
int tree[30*A][26],f[30*A],fail[30*A];
int tot;
LL sum[30*A];
void build(){
int p=0,len=strlen(a);
for(int i=0;i<len;i++){
if(!tree[p][a[i]-'a']){
tree[p][a[i]-'a']=++tot;
f[tot]=p;
}
p=tree[p][a[i]-'a'];
}
return;
}
void getfail(){
queue <int> q;
int p=0;
for(int i=0;i<26;i++){
p=tree[0][i];
if(!p) continue;
fail[p]=0;
q.push(p);
}
while(!q.empty()){
int x=q.front();q.pop();
for(int i=0;i<26;i++){
p=tree[x][i];
if(!p){
tree[x][i]=tree[fail[x]][i];
continue;
}
q.push(p);
fail[p]=tree[fail[x]][i];
}
}
return;
}
void solve(){
LL ans=0;
for(int i=1;i<=tot;i++)
for(int j=fail[i];j;j=fail[j])
sum[j]++;
for(int i=1;i<=tot;i++)
if(fail[i]){
int j=i,k=fail[i];
while(k){
j=f[j];
k=f[k];
}
ans+=sum[j];
}
printf("%lld\n",(LL)tot*tot-ans);
}
signed main(){
n=in;
for(int i=1;i<=n;i++){
scanf("%s",a);
build();
}
getfail();
solve();
return 0;
}