求出
height
数组,移动一个长度为
k
的区间。
对于
k=1
特判一下。
#include<algorithm>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
const int N = 1e5+7;
int sa[N],c[N],t1[N],t2[N],rk[N],h[N],dp[N][20],mm[N];
int n;
char s[N];
bool cmp(int *r, int a, int b, int l)
{
return r[a] == r[b] && r[a+l] == r[b+l] ;
}
//最后得到的 rk 在[0,n-1] 有效
//h和sa均在[1,n]内有效
void get_sa(char *s,int *sa,int *h,int *rk,int n,int m)
{
s[n]=0; ++n; //++n的原因是因为最后引入了一个0,因此预处理后显然sa[oldn]=0。
int p, *x = t1, *y = t2;
//基数排序:计算每个值的cnt,然后计算cnt的前缀和,标志每种值的在sa数组中的区间,
//然后将下标(后缀第一个字母的位置)放置到sa数组中。
for(int i=0;i<m;++i) c[i]=0;
for(int i=0;i<n;++i) ++c[x[i]=s[i]];
for(int i=1;i<m;++i) c[i]+=c[i-1];
for(int i=n-1;i>=0;--i) sa[--c[x[i]]]=i;
//k是倍增的大小,新的数组中,(x[i],x[i+k])是一个pair。
//k<=n 因为极限情况是0+n==n。
for(int k=1;k<=n;k<<=1)
{
//sa倍增
//y是pair中第二个关键字排好序后对应的第一个关键字的下标。
//当sa[i]>=k时,才有第一个关键字。
p=0;
for(int i=n-k;i<n;++i) y[p++]=i;
for(int i=0;i<n;++i) if(sa[i]>=k) y[p++]=sa[i]-k;
//基数排序,不过这里计算sa的时候根据y从后往前,因为考虑到第一关键字相同时,第二个关键字也是要升序的。
//x相当于s吧,不过每次被增后x会发生变化。
for(int i=0;i<m;++i) c[i]=0;
for(int i=0;i<n;++i) ++c[x[i]]; //这里和邝斌的不一样
for(int i=1;i<m;++i) c[i]+=c[i-1];
for(int i=n-1;i>=0;--i) sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1;x[sa[0]]=0; //sa[0]=n-1,相当于x[n-1]=0。
//给x赋新值,根据sa[i]升序,这样保证p最小。且相同的pair值一样。
for(int i=1;i<n;++i)
x[sa[i]] = cmp(y,sa[i-1],sa[i],k) ? p-1:p++;
if(p>=n) break; //不同的值已经达到n就退出。
m=p; //x中的最大值+1。
}
int k=0,j;
--n;
for(int i=0;i<=n;++i) rk[sa[i]]=i; //rk[n]=0
//for(int i=0;i<n;++i) printf("%d\n",rk[i]);
//h[rk[i]]>=h[rk[i]-1]-1 。
for(int i=0;i<n;++i)
{
if(k) --k;
j=sa[rk[i]-1];
while(s[i+k]==s[j+k]) ++k;
h[rk[i]]=k;
}
}
void init_RMQ()
{
mm[0]=-1;
for(int i=1;i<=n;++i)
{
mm[i]=((i&(i-1))==0)?mm[i-1]+1:mm[i-1];
dp[i][0]=h[i];
}
for(int j=1;j<=mm[n];++j)
for(int i=1;i+(1<<j)-1<=n;++i)
dp[i][j]=min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);
}
int rmq(int x,int y)
{
if(x>y) swap(x,y);
int k=mm[y-x+1];
return min(dp[x][k],dp[y-(1<<k)+1][k]);
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
int k;
scanf("%d",&k);
scanf("%s",s);
n=strlen(s);
get_sa(s,sa,h,rk,n,128);
if(k==1)
{
ll ans=(ll)(1+n)*n/2;
for(int i=2;i<=n;++i)
ans-=h[i]+max(h[i]-h[i-1],0);
printf("%I64d\n",ans);
continue;
}
init_RMQ();
ll ans=0;
for(int i=1;i<=n;++i)
{
if(i+k-1<=n) ans+=rmq(i+1,i+k-1)-rmq(i,i+k-1);
if(i+k<=n) ans+=rmq(i,i+k)-rmq(i+1,i+k);
}
printf("%I64d\n",ans);
}
return 0;
}