链接
题解
把两个字符串拼接之后,计算出后缀数组
假设 t t t串的排名是 r r r,那么所有排在 r r r前面的后缀的字典序都比 t t t小,这些后缀的所有前缀的字典序也比 t t t小
排名 > r >r >r的后缀,假设和 t t t的 l c p lcp lcp是 L L L,那么如果 l c p lcp lcp长度小于 t t t的长度,这个 l c p lcp lcp的所有前缀都要算进答案,否则取 ∣ t ∣ − 1 |t|-1 ∣t∣−1
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 400010
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
struct SuffixArray
{
int sa[maxn], rank[maxn], ws[maxn], wv[maxn], wa[maxn], wb[maxn], height[maxn], N;
bool cmp(int *r, int a, int b, int l){return r[a]==r[b] and r[a+l]==r[b+l];}
void build(int *r, int n, int m)
{
N=n;
n++;
int i, j, k=0, p, *x=wa, *y=wb, *t;
for(i=0;i<m;i++)ws[i]=0;
for(i=0;i<n;i++)ws[x[i]=r[i]]++;
for(i=1;i<m;i++)ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--)sa[--ws[x[i]]]=i;
for(p=j=1;p<n;j<<=1,m=p)
{
for(p=0,i=n-j;i<n;i++)y[p++]=i;
for(i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j;
for(i=0;i<n;i++)wv[i]=x[y[i]];
for(i=0;i<m;i++)ws[i]=0;
for(i=0;i<n;i++)ws[wv[i]]++;
for(i=1;i<m;i++)ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--)sa[--ws[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,i=1,x[sa[0]]=0;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
for(i=0;i<n;i++)rank[sa[i]]=i;
for(i=0;i<n-1;height[rank[i++]]=k)
for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
}
}SA;
char s[maxn], t[maxn];
int r[maxn];
int main()
{
int n, m, i, j;
ll ans=0;
scanf("%s%s",s,t);
n=strlen(s), m=strlen(t);
rep(i,0,n-1)r[i]=s[i];
r[n]='|';
rep(i,0,m-1)r[n+1+i]=t[i];
SA.build(r,n+m+1,300);
// rep(i,0,SA.N)
// {
// rep(j,SA.sa[i],SA.N-1)putchar(r[j]);
// putchar(10);
// }
rep(i,0,SA.rank[n+1])
{
if(SA.sa[i]<n)
{
// printf("%d\n",i);
ans += n-SA.sa[i];
}
}
int mn=iinf;
rep(i,i,SA.N)
{
mn = min(mn,SA.height[i]);
if(SA.sa[i]<n)
{
ans += min(mn,m-1);
}
}
printf("%lld",ans);
return 0;
}