多串SAM
两串中相同的字串会在同一节点
每个节点搞个siz表示这个节点代表的字串出现了多少次
当然还要加一维,表示在哪个串中出现的次数
于是答案就是sigma (len[i]-len[par[i]])*siz[0][i]*siz[1][i]了
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<queue>
#include<vector>
#include<algorithm>
#include<map>
#include<set>
#include<stack>
#define rep(i,l,r) for(int i=l;i<=r;i++)
#define per(i,r,l) for(int i=r;i>=l;i--)
#define mmt(a,v) memset(a,v,sizeof(a))
#define tra(i,u) for(int i=head[u];i;i=e[i].next)
using namespace std;
typedef long long ll;
const int N=200000+5;
const int M=800000+5;
int par[M],ch[M][26],siz[2][M],len[M],root,sz,last;
void init(){last=root=sz=1;}
void extend(int x){
int p=last,np=++sz;
len[np]=len[p]+1;
for(;p&&!ch[p][x];p=par[p])ch[p][x]=np;
if(!p)par[np]=root;
else{
int q=ch[p][x];
if(len[q]==len[p]+1)par[np]=q;
else{
int nq=++sz;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
par[nq]=par[q];
par[q]=par[np]=nq;
for(;p&&ch[p][x]==q;p=par[p])ch[p][x]=nq;
}
}
last=np;
}
int deg[M];
void toposort(){
rep(i,1,sz)deg[par[i]]++;
queue<int>q;
rep(i,1,sz)if(!deg[i])q.push(i);
while(!q.empty()){
int u=q.front();q.pop();
siz[0][par[u]]+=siz[0][u];
siz[1][par[u]]+=siz[1][u];
deg[par[u]]--;
if(!deg[par[u]])q.push(par[u]);
}
}
char s[N];
int main(){
//freopen("a.in","r",stdin);
scanf("%s",s+1);
int n=strlen(s+1);
init();
int p=root;
rep(i,1,n){
extend(s[i]-'a');
p=ch[p][s[i]-'a'];
siz[0][p]++;
}
scanf("%s",s+1);
n=strlen(s+1);
last=root;p=root;
rep(i,1,n){
extend(s[i]-'a');
p=ch[p][s[i]-'a'];
siz[1][p]++;
}
toposort();
ll ans=0;
rep(i,1,sz)
if(par[i])
ans+=1LL*(len[i]-len[par[i]])*siz[0][i]*siz[1][i];
printf("%lld\n",ans);
return 0;
}