Problem
求字符串 S 中严格出现 k 次的子串个数
k≥1
|S|≤105
∑|S|≤2×106
Idea
貌似很多队都是用后缀树 AC 的。好吧,我不会。
后缀数组 + 线段树 解法:
利用后缀数组处理出 height[]
数组,显然 height[i]
表示 sa[i]
与 sa[i-1]
的最长公共前缀(LCP) 。
利用线段树或者 ST 表存储 height[]
数组,要求能够做的
O(1) or O(logn)
获取区间内最小的 height 。
按序枚举每一个 sa[i]
,显然有效子串的下界为
max(0,heighti,heighti+k)
。上界为
min(height[i+1,i+k−1])
(当 k=1 时,上界为 sa[i]
的长度)。在串长在上下界之间的所有 sa[i]
的前缀子串都对答案恭喜为 1 。
感觉不好写,直接甩手扔给队友… : )
Code
#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int N = 100100;
int cmp(int *r,int a,int b,int l){
return (r[a]==r[b]) && (r[a+l]==r[b+l]);
}
int wa[N],wb[N],wc[N],wv[N];
int Rank[N],height[N];
int mn[N<<2];
void DA(char *r,int *sa,int n,int m){
int i,j,p,*x=wa,*y=wb,*t;
for(i=0;i<m;i++) wc[i]=0;
for(i=0;i<n;i++) wc[x[i]=r[i]]++;
for(i=1;i<m;i++) wc[i]+=wc[i-1];
for(i=n-1;i>=0;i--) sa[--wc[x[i]]]=i;
for(j=1,p=1;p<n;j*=2,m=p)
{
for(p=0,i=n-j;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(i=0;i<n;i++) wv[i]=x[y[i]];
for(i=0;i<m;i++) wc[i]=0;
for(i=0;i<n;i++) wc[wv[i]]++;
for(i=1;i<m;i++) wc[i]+=wc[i-1];
for(i=n-1;i>=0;i--) sa[--wc[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
}
void calheight(char *r,int *sa,int n) {
int i,j,k=0;
for(i=1;i<=n;i++) Rank[sa[i]]=i;
for(i=0;i<n; height[Rank[i++]] = k )
for(k?k--:0,j=sa[Rank[i]-1]; r[i+k]==r[j+k]; k++);
}
void build(int l,int r,int rt){
if(l==r){
mn[rt]=height[l];
return;
}
int mid=(l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
mn[rt]=min(mn[rt<<1],mn[rt<<1|1]);
}
int query(int l,int r,int rt,int L,int R){
if(L<=l && r<=R)
return mn[rt];
int mid=(l+r)>>1;
int ret=1e9;
if(mid>=L)
ret=min(ret,query(l,mid,rt<<1,L,R));
if(mid<R)
ret=min(ret,query(mid+1,r,rt<<1|1,L,R));
return ret;
}
char str[N];
int sa[N], nxt[N], T, k;
int main(){
scanf("%d", &T);
while(T--) {
scanf("%d %s",&k,str);
int n = strlen(str);
str[n]=0;
DA(str,sa,n+1,128);
calheight(str,sa,n);
build(1,n,1);
LL ans=0;
for(int i=1;i+k-1<=n;++i) {
int len1;
if(k>1)
len1=query(1,n,1,i+1,i+k-1);
else
len1=n-sa[i];
int len2=0;
if(i!=1)
len2=max(len2,height[i]);
if(i+k-1!=n)
len2=max(len2,height[i+k]);
if(len1>len2)
ans+=len1-len2;
}
printf("%lld\n", ans);
}
}