如果字符串是给定的,不询问区间,按照论文里写的,由于每个子串一定是某个后缀的前缀,相当于就是求后缀之间不相同前缀的个数。每次新加进来一个后缀,就是加上(n-sa[i]+1)个新前缀,但是其中有height[i]个前面已经算过了,减掉即可。
这题询问的是一个区间[l,r],一开始的思路是,对于每一个sa[i]在l到r之间的后缀,按前面那种方法计算,同时注意考虑一下lcp的右边超过了r的情况,过了样例交就wa了..后来看大神的博客,发现他说他当时错的跟我一样..就是没考虑区间中的sa数组和整个字符串的sa数组不一样的问题(一开始想到了来着..后来不知道为什么忘了..每次都是这样子)解决的办法就是加一句话,判断一下当前的这个后缀和上一个后缀的前后关系在区间中是否是正确的,具体见代码注释那句话
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int maxn=4010;
int sa[maxn],t[maxn],t2[maxn],c[maxn],n,s[maxn],mi[maxn][20];
int Rank[maxn],height[maxn];
char str[maxn];
void rmq()
{
int i,j;
for(i=1;i<n-1;i++)mi[i][0]=height[i+1];
for(j=1;(1<<j)<n-1;j++)
for(i=1;i+(1<<j)-1<n-1;i++)
mi[i][j]=min(mi[i][j-1],mi[i+(1<<(j-1))][j-1]);
}
int Q(int l,int r)
{
int k=(int)(log(1.0*(r-l+1))/log(2.0));
return min(mi[l][k],mi[r+1-(1<<k)][k]);
}
void get_sa(int m)
{
int i,*x=t,*y=t2;
for(i=0;i<m;i++) c[i]=0;
for(i=0;i<n;i++) c[x[i]=s[i]]++;
for(i=1;i<m;i++) c[i]+=c[i-1];
for(i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
for(int k=1;k<=n;k<<=1)
{
int p=0;
for(i=n-k;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=k) y[p++]=sa[i]-k;
for(i=0;i<m;i++) c[i]=0;
for(i=0;i<n;i++) c[x[y[i]]]++;
for(i=1;i<m;i++) c[i]+=c[i-1];
for(i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1;x[sa[0]]=0;
for(i=1;i<n;i++)
x[sa[i]]=y[sa[i-1]]==y[sa[i]] && y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++;
if(p>=n) break;
m=p;
}
}
void get_height()
{
int i,j,k=0;
for(i=1;i<n;i++) Rank[sa[i]]=i;
for(i=0;i<n-1;i++)
{
if(k)k--;
j=sa[Rank[i]-1];
while(s[i+k]==s[j+k]) k++;
height[Rank[i]]=k;
}
}
int solve(int l,int r)
{
int i,flag=0,ans=0,tmp=-1,t,b,sum;
for(i=1;i<n;i++)
{
if(sa[i]<=r && sa[i]>=l)
{
ans+=(r-sa[i]+1);
if(!flag) flag=1;
else
{
t=r-max(sa[i],sa[tmp])+1;
b=Q(tmp,i-1);
sum=b>t?t:b;
ans-=sum;
}
if(tmp==-1)tmp=i;
else
{
if(sa[tmp]<sa[i] && b>=r-sa[i]+1){}//就是少了这句话就wa了..
else tmp=i;
}
}
}
return ans;
}
int main()
{
int t,i,q,l,r;
scanf("%d",&t);
while(t--)
{
scanf("%s",str);
n=strlen(str);
for(i=0;i<n;i++) s[i]=str[i]-'a'+1;
s[n++]=0;
get_sa(30);
get_height();
rmq();
scanf("%d",&q);
while(q--)
{
scanf("%d%d",&l,&r);
l--;r--;
printf("%d\n",solve(l,r));
}
}
return 0;
}