题意:
给你一个字符串T,和n个字符串S,让你求出下面给出的公式F,F表示将 S i S_i Si和 S j S_j Sj拼接后在 T T T中的出现次数
∑ i = 1 n ∑ j = 1 n F ( t , s i + s j ) \sum_{i=1}^n\sum_{j=1}^nF(t,s_i+s_j) i=1∑nj=1∑nF(t,si+sj)
题解:
用n个字符串S正反两个AC自动机,第一个AC自动机处理以T的第i个字符结尾的串的种类数,第二个则能处理以T的第i+1个字符结尾的串的种类数,(第二个是S串全部反转建立的,相当于从右往左匹配),最后统计即可。
AC代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAXN = 2e5+50;
struct AC_Automaton{
int nxt[MAXN][26],fail[MAXN],cnt[MAXN],ans[MAXN];
int tot=1;
inline void Insert(char *s,int len){
int root=0;
for(int i=0;i<len;i++){
if(!nxt[root][s[i]-'a']) nxt[root][s[i]-'a'] = ++tot;
root = nxt[root][s[i]-'a'];
}
cnt[root]++;
}
inline void Build(){
queue<int> que;
for(int i=0;i<26;i++)
if(nxt[0][i])
cnt[nxt[0][i]] += cnt[0],que.push(nxt[0][i]);
while(!que.empty()){
int u = que.front(); que.pop();
for(int i=0;i<26;i++){
if(nxt[u][i]){
fail[nxt[u][i]] = nxt[fail[u]][i];
cnt[nxt[u][i]] += cnt[fail[nxt[u][i]]];
que.push(nxt[u][i]);
}
else nxt[u][i] = nxt[fail[u]][i];
}
}
}
inline void Solve(char *s,int len){
int root=0;
for(int i=0;i<len;i++)
root=nxt[root][s[i]-'a'],ans[i]=cnt[root];
}
}Ac1,Ac2;
char s[MAXN],t[MAXN];
int main(){
scanf("%s",s);
int ls=strlen(s);
int n; scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%s",t);
int lt=strlen(t);
Ac1.Insert(t,lt);
reverse(t,t+lt);
Ac2.Insert(t,lt);
}
Ac1.Build(); Ac2.Build();
Ac1.Solve(s,ls);
reverse(s,s+ls);
Ac2.Solve(s,ls);
LL res = 0;
for(int i=0;i<ls-1;i++)
res += 1LL*Ac1.ans[i]*Ac2.ans[ls-i-2];
printf("%lld\n",res);
return 0;
}