题意:
给你n个字符串,任意选两个字符串,答案加上前一个字符串前缀和后一个字符串的后缀最长相同的部分长度的平方。问你答案是多少
AC自动机:
想练一下AC自动机,就用这个写了,有时间的话再用广义后缀自动机写一下。
首先我们肯定是要枚举所有的字符串,问题是将它当做前缀还是后缀。把当前的字符串当成后缀的话,你的fail指针是这么跳的时候,会很难维护是否有的字符串的前缀已经被访问过了:
那么这种请况我一下子还想不出来怎么解决,有可能要用lca什么的吗?然后就换了一种思路:把当前枚举的字符串当成前缀,去匹配所有字符串的后缀,这个时候ac自动机insert的时候,num是要在当前字符串的末尾端点才能+1,表示到了末尾,然后所有字符串insert之后,对fail边建反边,于是就形成了一颗树,此时可以dfs去把num加到前面的后缀中。
然后枚举每个字符串的时候,也是要求出它的nex数组,然后枚举的时候,nex[i+1]就是当前串的最长公共前后缀,后面的情况剪掉前面的重复情况即可。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e6+5,R=26;
const ll mod=998244353;
bool vis[N];
string s[N];
int nex[N];
void KMP(int x){
int len=s[x].length();
nex[0]=-1,nex[1]=0;
int i=1,j=0;
while(i<len&&j<len){
if(j==-1||s[x][i]==s[x][j])nex[++i]=++j;
else j=nex[j];
}
}
struct Tire{
int nxt[N][R],fail[N],ed[N],num[N],vis[N];
vector<int>son[N];
int rt,tot,cnt;
int newnode(){
for(int i=0;i<R;i++)nxt[tot][i]=-1;
//l[tot]=0;
return tot++;
}
void init(){
memset(num,0,sizeof(num));
tot=cnt=0;
rt=newnode();
}
int insert(int x){
int now=rt,len=s[x].length();
for(int i=0;i<len;i++){
int val=s[x][i]-'a';
if(nxt[now][val]==-1)nxt[now][val]=newnode();
//l[nxt[now][val]]=l[now]+1;
now=nxt[now][val];
}
num[now]++;
return now;
}
void dfs(int x){
for(int i:son[x]){
dfs(i);
num[x]+=num[i];
}
}
void build(){
queue<int>q;
fail[rt]=rt;
for(int i=0;i<R;i++){
if(nxt[rt][i]==-1)nxt[rt][i]=rt;
else {
fail[nxt[rt][i]]=rt;
q.push(nxt[rt][i]);
son[rt].push_back(nxt[rt][i]);
}
}
while(!q.empty()){
int now=q.front();q.pop();
for(int i=0;i<R;i++){
if(nxt[now][i]==-1)nxt[now][i]=nxt[fail[now]][i];
else {
fail[nxt[now][i]]=nxt[fail[now]][i];
son[nxt[fail[now]][i]].push_back(nxt[now][i]);
q.push(nxt[now][i]);
}
}
}
dfs(0);
}
ll query(int x){
int now=rt,len=s[x].length();
ll ans=0;
for(ll i=0;i<len;i++){
now=nxt[now][s[x][i]-'a'];
ans=(ans+1ll*num[now]*(i+1)%mod*(i+1)-1ll*num[now]*nex[i+1]%mod*nex[i+1])%mod;
if(ans<0)ans+=mod;
}
return ans;
}
}ac;
int p[N];
int main()
{
cin.tie(0);
ios::sync_with_stdio(false);
int n;
cin>>n;
ac.init();
for(int i=1;i<=n;i++)
cin>>s[i],ac.insert(i);
ac.build();
ll ans=0;
for(int i=1;i<=n;i++)
KMP(i),ans=(ans+ac.query(i))%mod;
printf("%lld\n",ans);
return 0;
}
广义后缀自动机:
后缀自动机就是单个串做后缀自动机,广义后缀自动机就是多个串做后缀自动机,有三种做广义后缀自动机的方法:
1.将所有串连起来,中间用特殊符号隔开,然后在经过一些神奇的操作,我不会太难了
2.每次做完之后将last置为1,这样就是从头开始新串。
3.和2差不多,但是它先对所有串构建字典树然后bfs建自动机。
我选择的是解法2,因为这个简单。
那么和普通后缀自动机的区别就在于,这里add的时候有一个特判:if(nxt[p][c])
也就是说如果存在这条边,那么就继续往下,不用建立新点。但是此时如果新进来的位置导致endpos类需要分裂,那么last就是分裂出来的那个点nq。但是在外面,last就是np。
那么这道题先做出来广义后缀自动机,然后我没再枚举一遍所有的字符串,同时在自动机的trie树上跳,并且把当前的字符串编号和长度平方存到这个点。最后跳完之后,计数器num[now]++表示这里有一个后缀。
接下来建一个parent树,跳一遍parent_tree就相当于枚举了一遍所有的后缀,从上往下跳parent_tree的话就是上面的是下面的后缀。然后这个点的f表示现在的长度平方是多少,我们对于每一个位置的长度平方需要减去前面所有位置的长度平方和,然后dfs2再跑一遍的时候,我们再加上前面的所有位置的平方和,经过这两步操作之后,每个位置的平方不变,但是我们就能够知道每个位置的平方是什么,因为它不止一个串,所以对于某些串,当前的位置其实不是那个串的endpos等价类,如果不减掉前缀的平方和,就可能会多加值,其实也就表示每个串的endpos等价类长度所掌握的区间。(我也有点混乱,大概是这个意思吧)
它卡空间,需要用链式前向星代替一个vector。(或许全改成int也能过?)
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pil pair<int,ll>
const int N = 1e6+1000,M=1e5+5;
const ll mod=998244353;
struct GSAM{
int last,cnt,nxt[N*2][26],fa[N*2],l[N*2];
void init(){
last = cnt=1;
memset(nxt[1],0,sizeof nxt[1]);
fa[1]=l[1]=0;
}
int inline newnode(){
cnt++;
memset(nxt[cnt],0,sizeof nxt[cnt]);
fa[cnt]=l[cnt]=0;
return cnt;
}
void add(int c){
int p = last;
if(nxt[p][c]){
int q=nxt[p][c];
if (l[q]==l[p]+1){
last=q;
}
else{
int nq = newnode();
memcpy(nxt[nq],nxt[q],sizeof nxt[q]);
fa[nq] =fa[q];
l[nq] = l[p]+1;
fa[q] =nq;
while (nxt[p][c]==q){
nxt[p][c]=nq;
p = fa[p];
}
last=nq;
}
return ;
}
int np = newnode();
last = np;
l[np] =l[p]+1;
while (p&&!nxt[p][c]){
nxt[p][c] = np;
p = fa[p];
}
if (!p){
fa[np] =1;//表示这个字符在之前的字符串中没有出现过
}
else{
int q = nxt[p][c];
if (l[q]==l[p]+1){
fa[np] =q;
}
else{
int nq = newnode();
memcpy(nxt[nq],nxt[q],sizeof nxt[q]);
fa[nq] =fa[q];
l[nq] = l[p]+1;
fa[np] =fa[q] =nq;
while (nxt[p][c]==q){
nxt[p][c]=nq;
p = fa[p];
}
}
}
}
}sam;
string s[M];
struct node{
int to,next;
}e[N*2];
int cnt,head[N*2];
void add(int x,int y){
e[cnt].to=y;
e[cnt].next=head[x];
head[x]=cnt++;
}
vector<pil>vec[N*2];
int sum[N*2],f[N*2],num[N*2];
void dfs(int x){
for(pil &i:vec[x])i.second=1ll*(i.second-f[i.first]+mod)%mod;
for(pil i:vec[x])f[i.first]=1ll*(f[i.first]+i.second)%mod;
for(int i=head[x];~i;i=e[i].next){
int ne=e[i].to;
dfs(ne);
}
for(pil i:vec[x])f[i.first]=1ll*(f[i.first]-i.second+mod)%mod;
}
ll ans;
void dfs2(int x){
for(pil i:vec[x])sum[x]=1ll*(sum[x]+i.second)%mod;
ans=(ans+1ll*sum[x]*num[x])%mod;
for(int i=head[x];~i;i=e[i].next){
int ne=e[i].to;
sum[ne]=1ll*(sum[ne]+sum[x])%mod,dfs2(ne);
}
}
int main()
{
memset(head,-1,sizeof(head));
cin.tie(0);
ios::sync_with_stdio(false);
sam.init();
int n;
cin>>n;
for(int i=1;i<=n;i++){
cin>>s[i];
sam.last=1;
int len=s[i].length();
for(int j=0;j<len;j++)sam.add(s[i][j]-'a');
}
for(int i=1;i<=n;i++){
int now=1;
int len=s[i].length();
for(int j=0;j<len;j++){
now=sam.nxt[now][s[i][j]-'a'];
vec[now].push_back({i,1ll*(j+1)*(j+1)%mod});
}
num[now]++;
}
for(int i=2;i<=sam.cnt;i++)add(sam.fa[i],i);
dfs(1),dfs2(1);
printf("%lld\n",ans);
return 0;
}