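// 容斥原理+后缀数组。后缀数组上的计数问题总考,这种方法应该引起重视。
// Inclusion–exclusion + suffix array. Counting problems on suffix arrays come up
// frequently, so this technique deserves attention.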
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef int lint;
// Size of every per-position global array below (and of the input buffer).
const lint maxn = 500000 + 10;
// Suffix array built by prefix doubling with counting (radix) sort, plus the
// LCP ("height") array computed in Kasai style.
//
// Usage: fill s[0..len-1] with symbol codes, call Build_SA(len), then
// getheight(). Symbol codes are used directly as bucket indices into c[], so
// they must lie in [0, maxn).
struct suffix{
    // c:      counting-sort buckets.
    // t1/t2:  scratch buffers; one holds the current ranks, the other the
    //         order of suffixes by their second key.
    // sa[r]:  start index of the r-th smallest suffix.
    // rk[i]:  rank of the suffix starting at i (inverse of sa).
    // height[r]: LCP of suffixes sa[r-1] and sa[r]; height[0] = 0.
    // h[i]:   the same LCP indexed by text position, h[i] == height[rk[i]].
    // s:      input symbols; n: text length; m: largest key value in use.
    lint c[maxn],sa[maxn],t1[maxn],t2[maxn],m,n,h[maxn],height[maxn],rk[maxn],s[maxn];

    // Sorts all suffixes of s[0..len-1] into sa[].
    void Build_SA(lint len){
        m = 0;
        lint* x = t1,*y = t2;   // x: current ranks; y: second-key order / old ranks
        n = len;
        for( lint i = 0;i < n;i++ ) m = max( m,s[i] );

        // Initial counting sort by the first symbol alone.
        for( lint i = 0;i <= m;i++ ) c[i] = 0;
        for( lint i = 0;i < n;i++ ) c[ x[i] = s[i] ]++;
        for( lint i = 1;i <= m;i++ ) c[i] += c[i-1];
        for( lint i = n-1;i >= 0;i-- ) sa[ --c[ x[i] ] ] = i;

        for( lint k = 1;k < n;k <<= 1 ){
            // y = suffixes ordered by their second key (rank of the suffix at +k).
            // Suffixes whose +k position is past the end have an empty second
            // key and therefore come first.
            lint cnt = 0;
            for( lint i = n - 1;i >= n-k;i-- ) y[cnt++] = i;
            for( lint i = 0;i < n;i++ ) if( sa[i] >= k ) y[cnt++] = sa[i]-k;

            // Stable counting sort by the first key (x) preserves the y order.
            for( lint i = 0;i <= m;i++ ) c[i] = 0;
            for( lint i = 0;i < n;i++ ) c[ x[i] ]++;
            for( lint i = 1;i <= m;i++ ) c[i] += c[i-1];
            for( lint i = cnt-1;i >= 0;i-- ) sa[ --c[ x[ y[i] ] ] ] = y[i];

            // Re-rank by the (first key, second key) pair; y now holds old ranks.
            swap( x,y );
            lint num = 0;
            x[ sa[0] ] = 0;
            for( lint i = 1;i < n;i++ ){
                const lint a = sa[i-1], b = sa[i];
                // A second key past the end of the text is "empty" (-1). The
                // original code read y[a+k] here unconditionally. Slots >= n are
                // never written in this run, so they held stale data from
                // earlier inputs, and a+k can even reach 2n-2 >= maxn.
                const lint a2 = a + k < n ? y[ a+k ] : -1;
                const lint b2 = b + k < n ? y[ b+k ] : -1;
                x[ b ] = ( y[a] != y[b] || a2 != b2 ) ? ++num : num;
            }
            if( num == n-1 ) return;   // all ranks distinct: sa is final
            m = num;
        }
    }

    // Fills rk[], h[] and height[] from sa[] (call after Build_SA).
    void getheight(){
        for( lint i = 0;i < n;i++ ){
            rk[ sa[i] ] = i;
        }
        lint cnt = 0;          // LCP carried over from the previous text position
        h[ sa[0] ] = 0;
        height[ 0 ] = 0;
        for( lint i = 0;i < n;i++ ){
            if( rk[i] == 0 ){
                // The smallest suffix has no predecessor: its LCP is 0.
                cnt = 0;
                h[ i ] = 0;
                height[ 0 ] = 0;
                continue;
            }
            // LCP can shrink by at most 1 when moving from suffix i-1 to i.
            if( cnt ) cnt--;
            const lint j = sa[ rk[i]-1 ];   // suffix ranked just before suffix i
            while( i + cnt < n && j + cnt < n && s[ j+cnt ] == s[ i+cnt ] ) cnt++;
            h[ i ] = cnt;
            height[ rk[i] ] = cnt;
        }
    }
}g;
// Sparse table: after build2, rmq[j][p] holds min(height[j .. j+2^p-1]).
// lg[v] holds floor(log2(v)) after init.
lint rmq[maxn][20],lg[maxn];

// Fills lg[1 .. t-s+1] with floor(log2(v)). Only the width t-s+1 matters;
// s and t themselves are not used as indices.
void init( lint s,lint t ){
    const lint upto = t - s + 1;
    lg[1] = 0;
    for( lint v = 2; v <= upto; ++v ){
        // floor(log2(v)) == floor(log2(v/2)) + 1 for every v >= 2
        lg[v] = lg[v >> 1] + 1;
    }
}
// Builds a min sparse table over g.height[s .. t-1] into rmq:
// rmq[j][p] = min(g.height[j .. j+2^p-1]). Requires lg[] to cover t-s.
void build2( lint s,lint t,lint rmq[][20] ){
    // Level 0: every entry covers a single element.
    for( lint pos = s; pos < t; ++pos ){
        rmq[pos][0] = g.height[pos];
    }
    const lint top = lg[t - s];
    for( lint lev = 1; lev <= top; ++lev ){
        const lint half = 1 << (lev - 1);
        // Combine two adjacent half-width windows that fit inside [s, t).
        for( lint j = s; j + 2 * half <= t; ++j ){
            rmq[j][lev] = min( rmq[j][lev - 1], rmq[j + half][lev - 1] );
        }
    }
}
// Longest common prefix of the suffixes at SA ranks x and y.
// - x == y: the length of that suffix, len - g.sa[x].
// - otherwise: min(height[min+1 .. max]) via the sparse table, or 0 when the
//   larger rank is >= len (i.e. there is no such suffix).
lint lcp( lint x,lint y,lint rmq[][20],lint len ){
    if( x == y ) return len - g.sa[x];
    // Half-open height range [lo, hi) == [min(x,y)+1, max(x,y)].
    const lint lo = min( x,y ) + 1;
    const lint hi = max( x,y ) + 1;
    if( lo >= len || hi > len ) return 0;
    const lint lev = lg[hi - lo];
    // Two overlapping power-of-two windows cover [lo, hi).
    return min( rmq[lo][lev],rmq[hi - ( 1 << lev )][lev] );
}
// Counts the distinct substrings that occur exactly k times, using
// inclusion–exclusion over every window of k consecutive SA ranks [lo, hi].
// The window's common-prefix lengths in (max(left, right), inside] are shared
// by exactly these k suffixes. Here "inside" is the window's LCP, "right" is
// the LCP after adding rank hi+1, and "left" is the LCP with rank lo-1.
LL solve( lint n,lint k ){
    LL total = 0;
    for( lint lo = 0, hi = k - 1; hi < n; ++lo, ++hi ){
        const lint inside = lcp( lo,hi,rmq,n );
        const lint right = lcp( lo,hi + 1,rmq,n );
        const lint left = g.height[lo];
        const lint gain = inside - max( right,left );
        if( gain > 0 ) total += gain;
    }
    return total;
}
char s[maxn];
int main()
{
lint A,B;
lint T,k;
scanf("%d",&T);
while( T-- ) {
scanf("%d%s",&k,s);
lint n = strlen( s );
for( lint i =0;i < n;i++ ) g.s[i] = s[i];
g.Build_SA(n);
g.getheight();
init( 0,n );
build2( 0,n,rmq );
LL ans = solve(n,k) ;
printf("%lld\n",ans);
}
return 0;
}