补博客!
首先我们观察题目中给的那个求\(ans\)的方法,其实前两项没什么用处,直接\(for\)一遍就求得了
for (int i=1;i<=n;i++) ans=ans+i*(n-1);
那么我们考虑剩下的部分应该怎么求解!
首先这里有一个性质。对于任意两个后缀\(i,j\),他们的\(lcp\)长度是他们对应的\(rank\)之间的\(height\)的\(min\) (左开右闭)
或者这样说
\(lcp(i,j) = min(height[rank[i]+1],height[rank[i]+2].....,height[rank[j]]) 其中rank[i]<rank[j]\)
那么对于这个题,我们就可以直接维护出每个\(height\)作为最小值的区间,然后用他的区间个乘上贡献即可(但是具体这里求的时候需要仔细想想,因为那个左开右闭的区间,假设右边能选的端点是\(r[i]-l+1\),那么合法的右端点实际上是由\(i-l[i]+1\)因为,能覆盖到\(l[i]\)这个\(height\)的点实际上是\(l[i]-1\)。)
总之就是比较难理解啊
for (int i=1;i<=n;i++) ans=ans-2ll*(r[i]-i+1)*(i-l[i]+1)*height[i];
那么现在的问题就是应该怎么求\(l[i]和r[i]\)呢?
QWQ这貌似是单调栈的经典应用?
直接从左到右,从右到左扫两遍即可.
这里有一个很好的防止计算重复的方法
就是我们从左到右扫维护的栈是单调的。然后从右到左不单调(非严格)
或者说,一遍单调,一遍不单调,即可解决重复的问题了!
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 2e6+1e2;
struct Node{
int val,pos;
};
int wb[maxn],sa[maxn];
Node s[maxn];
int l[maxn],r[maxn];
int rk[maxn],h[maxn],height[maxn];
int tmp[maxn];
int n,m;
char a[maxn];
int ans;
void getsa()
{
int *x = rk,*y = tmp;
int s = 128;
int p = 0;
for (int i=1;i<=n;i++) x[i]=a[i],y[i]=i;
for (int i=1;i<=s;i++) wb[i]=0;
for (int i=1;i<=n;i++) wb[x[y[i]]]++;
for (int i=1;i<=s;i++) wb[i]+=wb[i-1];
for (int i=n;i>=1;i--) sa[wb[x[y[i]]]--] = y[i];
for (int j=1;p<n;j<<=1)
{
p=0;
for (int i=n-j+1;i<=n;i++) y[++p]=i;
for (int i=1;i<=n;i++) if (sa[i]>j) y[++p]=sa[i]-j;
for (int i=1;i<=s;i++) wb[i]=0;
for (int i=1;i<=n;i++) wb[x[y[i]]]++;
for (int i=1;i<=s;i++) wb[i]+=wb[i-1];
for (int i=n;i>=1;i--) sa[wb[x[y[i]]]--] =y[i];
swap(x,y);
p=1;
x[sa[1]]=1;
for (int i=2;i<=n;i++)
{
x[sa[i]] = (y[sa[i-1]]==y[sa[i]] && y[sa[i]+j]==y[sa[i-1]+j]) ? p : ++p;
}
s=p;
}
for (int i=1;i<=n;i++) rk[sa[i]]=i;
h[0]=0;
for (int i=1;i<=n;i++)
{
h[i]=max(h[i-1]-1,(long long)0);
while (i+h[i]<=n && sa[rk[i]-1]+h[i]<=n && a[i+h[i]]==a[sa[rk[i]-1]+h[i]]) h[i]++;
}
for (int i=1;i<=n;i++) height[i] = h[sa[i]];
}
int top;
signed main()
{
scanf("%s",a+1);
n = strlen(a+1);
getsa();
for (int i=1;i<=n;i++) ans=ans+i*(n-1);
l[1]=1;
s[++top].val=height[1];
s[1].pos=1;
for (int i=2;i<=n;i++)
{
while (top>=1 && s[top].val>=height[i]) top--;
if (!top) l[i]=1;
else l[i]=s[top].pos+1;
s[++top].val=height[i];
s[top].pos=i;
}
memset(s,0,sizeof(s));
top=1;
r[n]=n;
s[top].val=height[n];
s[top].pos=n;
for (int i=n-1;i>=1;i--)
{
while (top>=1 && s[top].val>height[i]) top--;
if (!top) r[i]=n;
else r[i]=s[top].pos-1;
s[++top].val=height[i];
s[top].pos=i;
}
for (int i=1;i<=n;i++) ans=ans-2ll*(r[i]-i+1)*(i-l[i]+1)*height[i];
cout<<ans;
return 0;
}