那么为什么要减去前一个的 height 呢?因为同一个串我们只算一遍,所以要减去之前已经算过的部分。
#include<cstdio>
#include<algorithm>
// Capacity of every global work array below.
#define maxn 100100
using namespace std;
// Most recent character read from stdin by init().
char ch;
// ans : running total accumulated in work() and printed by main().
// tot : number of characters stored so far by init().
// n   : length of the stored string (set to tot at the end of init()).
// ws  : counting-sort buckets used by da().
// x,y : rank / ordering buffers handed to da(); calc() then refills x with
//       the rank of each suffix start position.
// wv  : per-round sort keys gathered inside da().
// a   : the input characters, stored 1-based by init().
// h   : filled by calc(); h[r] is the number of leading characters shared by
//       the suffixes at sorted positions r and r - 1.
// sa  : suffix start positions in sorted order, 1-based, filled by da().
// t   : passed as da()'s third argument; da() only uses that pointer
//       parameter as a swap temporary and never reads or writes the array.
int ans,tot,n,ws[maxn],x[maxn],y[maxn],wv[maxn],a[maxn],h[maxn],sa[maxn],t[maxn];
// Tells whether character code x belongs to the accepted input alphabet:
// an ASCII letter of either case, or one of ',', '.', '!'.
bool get(int x){
    const bool isLower = ('a' <= x) && (x <= 'z');
    const bool isUpper = ('A' <= x) && (x <= 'Z');
    if (isLower || isUpper) {
        return true;
    }
    switch (x) {
        case ',':
        case '.':
        case '!':
            return true;
        default:
            return false;
    }
}
void init(){
while (ch = getchar(),(! get(ch)));
a[++ tot] = ch;
while (ch = getchar(),(get(ch))) a[++ tot] = ch;
n = tot;
}
// Returns true when positions a and b carry the same pair of ranks:
// rank[a] == rank[b] and rank[a + l] == rank[b + l].
// The second pair is only read when the first pair already matches.
bool cmp(int *rank,int a,int b,int l){
    if (rank[a] != rank[b]) {
        return false;
    }
    return rank[a + l] == rank[b + l];
}
// Builds sa[1..n], the start positions of the suffixes of a[1..n] in sorted
// order, by repeatedly doubling the compared prefix length.
//
// x : buffer that holds the current rank of each position.
// y : buffer that holds the positions ordered by the second half of the key.
// t : used purely as a temporary while swapping the x and y pointers.
//
// Every sort is a counting sort over the global buckets ws[]. ws[0] is never
// cleared here, so the code relies on it being zero and on every key being
// at least 1.
void da(int *x,int *y,int *t){
    int keyMax = 122;  // largest key value of the current round ('z' at first)

    // Round 0: order the positions by their single character.
    for (int k = 1; k <= keyMax; ++k) {
        ws[k] = 0;
    }
    for (int pos = 1; pos <= n; ++pos) {
        x[pos] = a[pos];
        ++ws[x[pos]];
    }
    for (int k = 2; k <= keyMax; ++k) {
        ws[k] += ws[k - 1];
    }
    for (int pos = n; pos >= 1; --pos) {
        sa[ws[x[pos]]] = pos;
        --ws[x[pos]];
    }

    int span = 1;      // prefix length already sorted
    int nextRank = 1;  // one more than the number of distinct ranks assigned
    while (nextRank <= n) {
        // Order by the second half: positions whose second half runs off the
        // end of the string come first, then the rest in current sa order.
        int filled = 0;
        for (int pos = n - span + 1; pos <= n; ++pos) {
            y[++filled] = pos;
        }
        for (int r = 1; r <= n; ++r) {
            if (sa[r] > span) {
                y[++filled] = sa[r] - span;
            }
        }

        // Stable counting sort of that order by the first-half rank.
        for (int r = 1; r <= n; ++r) {
            wv[r] = x[y[r]];
        }
        for (int k = 1; k <= keyMax; ++k) {
            ws[k] = 0;
        }
        for (int r = 1; r <= n; ++r) {
            ++ws[wv[r]];
        }
        for (int k = 1; k <= keyMax; ++k) {
            ws[k] += ws[k - 1];
        }
        for (int r = n; r >= 1; --r) {
            sa[ws[wv[r]]] = y[r];
            --ws[wv[r]];
        }

        // Swap buffers: y now holds the old ranks, x receives the new ones.
        t = x;
        x = y;
        y = t;

        x[sa[1]] = 1;
        nextRank = 2;
        for (int r = 2; r <= n; ++r) {
            if (cmp(y, sa[r - 1], sa[r], span)) {
                // Same pair of old ranks as the previous suffix: share its rank.
                x[sa[r]] = nextRank - 1;
            } else {
                x[sa[r]] = nextRank;
                ++nextRank;
            }
        }

        span *= 2;
        keyMax = nextRank;
    }
}
// Fills h[] from sa[]: for each start position i, h[rank of i] receives the
// number of leading characters that the suffix at i shares with the suffix
// placed immediately before it in sa[].
//
// The global x[] is first overwritten with the inverse of sa[] (x[sa[r]] = r).
// Positions are then visited in text order and the match length carried over
// from the previous position is reduced by one instead of restarting at zero.
void calc(){
    for (int r = 1; r <= n; ++r) {
        x[sa[r]] = r;
    }

    int matched = 0;
    for (int pos = 1; pos <= n; ++pos) {
        if (matched != 0) {
            --matched;
        }
        const int prev = sa[x[pos] - 1];  // start of the preceding suffix in sa order
        while (a[pos + matched] == a[prev + matched]) {
            ++matched;
        }
        h[x[pos]] = matched;
    }
}
// Runs the suffix-array and height computations, then adds to ans every
// positive increase of h[] between consecutive sorted positions.
// Only the amount by which h[r] exceeds h[r - 1] is added, so a shared prefix
// that was already accounted for at the previous position is not added again.
void work(){
    da(x, y, t);
    calc();

    for (int r = 1; r <= n; ++r) {
        const int gain = h[r] - h[r - 1];
        if (gain > 0) {
            ans += gain;
        }
    }
}
// Program entry: read the string, compute the answer, print it as a decimal
// integer with no trailing newline.
int main(){
    init();
    work();
    printf("%d", ans);
    return 0;
}