题目来源:http://www.lydsy.com/JudgeOnline/problem.php?id=3238
Description
Input
一行,一个字符串S
Output
一行,一个整数,表示所求值
Sample Input
cacao
Sample Output54
HINT
2<=N<=500000,S由小写英文字母组成
后缀数组+单调栈
贴一下个人感觉不错的一个题解:
http://www.cnblogs.com/Tunix/p/4211675.html
引用:
这个题目的要求……我们很明显可以直接预处理出来T(i)+T(j)的总和,为n*(n-1)(n+1)/2。(应该挺容易推的吧?自己画一下(样例):当i为1的时候,j可以为2、3、4、5。则1算了4次,2~5各一次;然后2算三次,3~5算一次;3算两次,4、5算一次;4算一次,5算1次。总共加起来,每个数各算了4次。1~5的和是n(n+1)/2,总共算了(n-1)次,再乘一下就行了。)
难点在于LCP的减……这个地方我们可以直接在height数组上搞,我们可以发现,每一对(i,j)都对应了height数组上的一段区间、甚至是点!(当i,j两个子串rank相连的时候)那么同样的,每一个height数组上的一段区间(点)也对应我们要求的一个LCP。
这样有什么好处呢?原本暴力的做法是枚举i,j,用RMQ算LCP,再减;而现在我们转换成直接算LCP,而不用考虑是谁和谁的LCP(事实上并不会落下任何一个),这样问题就简单多了:利用求LCP的特殊性,我们对于每一个height[i],都能找到一段区间[l,r],使得height[i]=min(height[l]~height[r])。额意思就是
i 是[l,r]这个区间上的最小值。这样LCP=height[i]的对数为 (l-i+1)*(r-i+1)
【ps.我们事先说过,[i,i]这样的一个点也算】也就是说我们的答案里可以减去 2*height[i](l-i+1)(r-i+1)
这样一个值。但是!!这样会有重复计算的情况:
举个栗子:height为 1 2 3 1 2 1 1时,第一个1的[l,r]区间为[1,7],第二个为[1,7],明显有重复计算了([l,i] 和
[i,r]这两段有重叠,也就是计算了两次)所以我们在计算对于每个 i
所能到达的[l,r]区间时,遇到相等元素,必须分开处理:比如如果向右遇到相等元素则可以继续扩展,而向左遇到则停止。(当然你反过来做应该也可以……)现在分析清楚了,最后的问题是:怎么算l[i],r[i],也就是每个height[i]对应的区间?
这里我们可以利用一个叫做单调栈的东西,维护栈里的元素height[j]都比当前的height[j]要小,如果大则弹出,这样就能O(N)求出所有的l[i],r[i]了。
错误:1.计算答案的时候,必须要在乘法中加上 (LL)类型强制转换,否则会出错。
2.在栈为空的时候,意味着左边(右边)所有的元素都比当前的要大,则范围应为从端点(1或n)到i的整个区间,而不是i。
下面贴一下根据自己习惯改的代码(感觉自己还是太弱了,希望自己直接想出并实现类似的题):
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=500009;
typedef long long LL;
int l[maxn],r[maxn];
char a[maxn],b[maxn],s[maxn];
int lena,lenb;
int sa[maxn],rank[maxn],height[maxn];
int wa[maxn],ws[maxn],wv[maxn],wb[maxn];
int st[maxn<<1],top=0,n,m;
bool cmp (int *r,int a,int b,int l) {
return r[a]==r[b] && r[a+l]==r[b+l];
}
void da (char *r,int *sa,int n,int m) {
int *x=wa, *y=wb;
memset(ws,0,sizeof(ws));
for (int i=0;i<n;i++) ws[x[i]=r[i]]++;
for (int i=1;i<m;i++) ws[i]+=ws[i-1];
for (int i=n-1;i>=0;i--) sa[--ws[x[i]]]=i;
int p=1;
for (int j=1;p<n;j<<=1,m=p) {
p=0;
for (int i=n-j;i<n;i++) y[p++]=i;
for (int i=0;i<n;i++) if (sa[i]>=j) y[p++]=sa[i]-j;
memset(ws,0,sizeof(ws));
for (int i=0;i<n;i++) ws[wv[i]=x[y[i]]]++;
for (int i=1;i<m;i++) ws[i]+=ws[i-1];
for (int i=n-1;i>=0;i--) sa[--ws[wv[i]]]=y[i];
swap(x,y); x[sa[0]]=0; p=1;
for (int i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
}
void callheight (char *r,int *sa,int n) {
for (int i=1;i<=n;i++) rank[sa[i]]=i;
for (int i=0,k=0;i<n;height[rank[i++]]=k) {
k?k--:0;
for (int j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
}
}
int main () {
scanf("%s",s);
int n=strlen(s);
for (int i=0;i<n;i++) s[i]=s[i]-'a'+2;
s[n]=0;
da(s,sa,n+1,30);
callheight(s,sa,n);
height[1]=height[n+1]=0;
LL ans=(LL)((LL)n*(n-1)*(n+1))/2,delta=0;
top=0;
st[top++]=1;
for (int i=1;i<=n;i++) {
while (top && height[st[top-1]] > height[i]) top--;
if (top) l[i]=st[top-1]+1; else l[i]=1;
st[top++]=i;
}
top=0;
st[top++]=n;r[n]=n;
for (int i=n;i>=1;i--) {
while (top && height[st[top-1]] >= height[i]) top--;
if (top) r[i]=st[top-1]-1;
else r[i]=n;
st[top++]=i;
}
for (int i=2;i<=n;i++)
delta+=(LL)2*(LL)height[i]*(LL)(i-l[i]+1)*(LL)(r[i]-i+1);
ans-=delta;
printf("%lld\n",ans);
return 0;
}