题意
定义两字符串相似,当且仅当存在至多一个 i i i,使得这两个字符串中只有第 i i i 个字母不同。取出一个字符串的所有长为 m m m 的子串,问:对于每个字符串,其它长为 m m m 的子串中有多少个与它相似。 n , m ≤ 1 0 5 n,m\leq 10^5 n,m≤105,时限 3 s。
题解
S , T S,T S,T 相似等价于 LCP ( S , T ) + LCS ( S , T ) ≥ m − 1 \operatorname{LCP}(S, T) + \operatorname{LCS}(S, T) ≥ m-1 LCP(S,T)+LCS(S,T)≥m−1。官方题解用的前/后缀树,并且树上启发式合并。然而用 height 数组做与之同理,所以我写的序列上的启发式合并。
首先预处理出字符串和其反串的 SA、height 等信息。对于每个子串 S [ i . . . i + m − 1 ] S[i...i+m-1] S[i...i+m−1],可以把它看成二维点对 ( r k [ i ] , r e v _ r k [ n − ( i − m + 1 ) + 1 ] ) (rk[i],rev\_rk[n-(i-m+1)+1]) (rk[i],rev_rk[n−(i−m+1)+1])。
将它们按 x x x 排序并分治,每次将 ( l , r ] (l,r] (l,r] 中 h e i g h t height height 最小的拿来做 m i d mid mid。开一个树状数组记录启发式合并到了的点的 y y y。显然跨过 m i d mid mid 的两个点(对应的子串)的最长公共前缀为 h e i g h t p [ m i d ] . x height_{p[mid].x} heightp[mid].x。先递归计算两个区间较短的一个并清空树状数组,再计算较长的一个并保留树状数组。
准备一个数组和一个数据结构分别维护长区间对短区间的贡献和短区间对长区间的贡献。然后枚举短区间的每一个点,统计与它最长公共后缀 ≥ m − 1 − h e i g h t p [ m i d ] . x \geq m-1-height_{p[mid].x} ≥m−1−heightp[mid].x 的在树状数组里的点的数量,加进,显然它们是 r e v _ r k rev\_rk rev_rk 连续的一段,二分出这个范围然后在树状数组上查一下,把这个结果加进数组中(长对短的贡献)。并且在数据结构中把长区间中树状数组里为 1 的地方加 1(短对长的贡献)。最后把短区间的每一个数往树状数组里加。
每次分治都只枚举短区间,时间复杂度 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
代码:
#include<bits/stdc++.h>
using namespace std;
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
int getint(){
int ans=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
const int N=2e5+10,L=19;
char s[N],rev_s[N];
int n,m;
struct SA{
char *s;
int sa[N],rk[N],tp[N],h[N];
int c[N];
void rsort(int m){
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)c[rk[tp[i]]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[rk[tp[i]]]--]=tp[i];
}
bool cmp(int x,int y,int k){
return tp[x]==tp[y]&&tp[x+k]==tp[y+k];
}
void get_sa(){
for(int i=1;i<=n;i++)rk[i]=s[i],tp[i]=i;
int m=128;
rsort(m);
for(int k=1,p=0;p<n;m=p,k<<=1){
p=0;
for(int i=n-k+1;i<=n;i++)tp[++p]=i;
for(int i=1;i<=n;i++)
if(sa[i]>k)tp[++p]=sa[i]-k;
rsort(m);
memcpy(tp,rk,sizeof(tp));
rk[sa[1]]=p=1;
for(int i=2;i<=n;i++)
rk[sa[i]]=cmp(sa[i],sa[i-1],k)?p:++p;
}
int t=0;
for(int i=1;i<=n;i++){
if(t)--t;
while(s[sa[rk[i]-1]+t]==s[i+t])++t;
h[rk[i]]=t;
}
}
int st[L][N],l2[N];
void init_st(){
l2[0]=-1;for(int i=1;i<=n;i++)st[0][i]=h[i+1],l2[i]=l2[i>>1]+1;
for(int i=1;i<L;i++){
for(int j=1;j+(1<<i-1)<=n;j++){
st[i][j]=min(st[i-1][j],st[i-1][j+(1<<i-1)]);
}
}
}
int get_min(int l,int r){
if(l==r)return n-sa[l]+1;
int t=l2[r-l];
return min(st[t][l],st[t][r-(1<<t)]);
}
int getl(int p,int len){
int l=1,r=p-1,mid=0,ans=p;
while(l<=r){
mid=l+r>>1;
if(get_min(mid,p)>=len)ans=mid,r=mid-1;
else l=mid+1;
}
return ans;
}
int getr(int p,int len){
int l=p+1,r=n,mid=0,ans=p;
while(l<=r){
mid=l+r>>1;
if(get_min(p,mid)>=len)ans=mid,l=mid+1;
else r=mid-1;
}
return ans;
}
void init(char *str){
s=str;
get_sa();
init_st();
}
} sa,rev_sa;
pii p[N];
int lowbit(int x){
return x&-x;
}
int a[N];
int ans1[N<<2],sz[N<<2],tag[N<<2],cnt=0;//记录由短区间到长区间的贡献
void pushup(int x){
sz[x]=sz[x<<1]+sz[x<<1|1];
ans1[x]=ans1[x<<1]+ans1[x<<1|1];
}
void pushdown(int x){
tag[x<<1]+=tag[x];tag[x<<1|1]+=tag[x];
ans1[x<<1]+=tag[x]*sz[x<<1];
ans1[x<<1|1]+=tag[x]*sz[x<<1|1];
tag[x]=0;
}
void ins(int pos,int val,int x,int nl,int nr){
if(nl==nr){
sz[x]+=val;
return;
}
pushdown(x);
int mid=nl+nr>>1;
if(pos<=mid)ins(pos,val,x<<1,nl,mid);
else ins(pos,val,x<<1|1,mid+1,nr);
pushup(x);
}
void modify(int l,int r,int val,int x,int nl,int nr){
if(nl>r||nr<l)return;
if(l<=nl&&nr<=r){
tag[x]+=val;
ans1[x]+=val*sz[x];
return;
}
pushdown(x);
int mid=nl+nr>>1;
modify(l,r,val,x<<1,nl,mid);modify(l,r,val,x<<1|1,mid+1,nr);
pushup(x);
}
int query(int l,int r,int x,int nl,int nr){
if(nl>r||nr<l)return 0;
if(l<=nl&&nr<=r){
return ans1[x];
}
pushdown(x);
int mid=nl+nr>>1;
return query(l,r,x<<1,nl,mid)+query(l,r,x<<1|1,mid+1,nr);
}
int ans2[N];//记录由长区间对短区间的贡献
void modify(int *a,int x,int val){
for(;x<=n;x+=lowbit(x))a[x]+=val;
}
int query(int *a,int x){
int ans=0;for(;x;x-=lowbit(x))ans+=a[x];return ans;
}
int stt[L][N],l2[N];
void init_st(){
l2[0]=-1;for(int i=1;i<=n;i++)stt[0][i]=i+1,l2[i]=l2[i>>1]+1;
for(int i=1;i<L;i++)for(int j=1;j<=n-(1<<i-1);++j){
int p=stt[i-1][j],q=stt[i-1][j+(1<<i-1)];
stt[i][j]=sa.h[p]<sa.h[q]?p:q;
}
}
int get_mid(int l,int r){
int t=l2[r-l];
int p=stt[t][l],q=stt[t][r-(1<<t)];
return sa.h[p]<sa.h[q]?p:q;
}
void solve(int l,int r,bool emp=0){
if(l==r&&!emp){
ins(p[l].se,1,1,1,n+1);
modify(a,p[l].se,1);
}
if(l>=r)return;
int mid=get_mid(l,r);
int len=m-1-sa.h[mid];
int l1=mid,r1=r,l2=l,r2=mid-1;
if(r1-l1>r2-l2)swap(l1,l2),swap(r1,r2);
{
solve(l1,r1,1);
solve(l2,r2,0);
for(int i=l1;i<=r1;i++){
if(p[i].se>n)continue;
int ql=rev_sa.getl(p[i].se,len),
qr=rev_sa.getr(p[i].se,len);
int cnt=query(a,qr)-query(a,ql-1);
ans2[i]+=cnt;
modify(ql,qr,1,1,1,n+1);
}
for(int i=l1;i<=r1;i++){
ins(p[i].se,1,1,1,n+1);
modify(a,p[i].se,1);
}
if(emp){
for(int i=l;i<=r;i++){
ins(p[i].se,-1,1,1,n+1);
modify(a,p[i].se,-1);
}
}
}
return;
}
int main(){
int size=40<<20;//40M
#ifndef ONLINE_JUDGE
__asm__ ("movl %0,%%esp\n"::"r"((char*)malloc(size)+size));
#else
__asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));
#endif
n=getint(),m=getint();
scanf("%s",s+1);
for(int i=1;i<=n;i++)rev_s[n-i+1]=s[i];
sa.init(s);
rev_sa.init(rev_s);
init_st();
int o=n-m+1;
for(int i=1;i<=o;i++){
p[i]=mp(sa.rk[i],rev_sa.rk[n-(i+m-1)+1]);
}
for(int i=o+1;i<=n;i++)p[i]=mp(sa.rk[i],n+1);
sort(p+1,p+n+1);
solve(1,n);
for(int i=1;i<=o;i++){
printf("%d ",ans2[sa.rk[i]]+
query(rev_sa.rk[n-(i+m-1)+1],rev_sa.rk[n-(i+m-1)+1],1,1,n+1));
}
exit(0);
}