给你一个母串 和 一个子串 问母串中有多少个子串更改不超过三个字母 可以和子串相匹配。
做法: 把母串和子串连起来建sa数组 然后 对母串的开头 和子串的开头进行枚举,如果首字母相同则求下一个跳到当前位置+lcp(母串当前位置,子串当前位置) 然后继续往下比较 最多只要跳不超过3次 就可以完成一次 开头的枚举
建sa O(T*nlog(n)) n为母串加子串的长度
暴力 O(T*3*n) n为子串的长度
最终 O(T*nlog(n)) n为母串加子串的长度
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#define ll long long
#define RG register
using namespace std;
const int maxn=210000; //字符总长度加上连接字符的个数
int n;
int s[maxn],t1[maxn],t2[maxn],rk[maxn],c[maxn],height[maxn],sa[maxn],dp[maxn][19];
int cmp(int *r,int a,int b,int k)
{
return r[a]==r[b]&&r[a+k]==r[b+k];
}
void build_sa(int n,int m)
{
int *x=t1,*y=t2;
for(int i=0;i<m;i++)c[i]=0;
for(int i=0;i<n;i++)c[x[i]=s[i]]++;
for(int i=1;i<m;i++)c[i]+=c[i-1];
for(int i=n-1;i>=0;i--)sa[--c[x[i]]]=i;
int p;
for(int k=1;k<=n;k<<=1,m=p)
{
p=0;
for(int i=n-k;i<n;i++)y[p++]=i;
for(int i=0;i<n;i++)if(sa[i]>=k)y[p++]=sa[i]-k;
for(int i=0;i<m;i++)c[i]=0;
for(int i=0;i<n;i++)c[x[y[i]]]++;
for(int i=1;i<m;i++)c[i]+=c[i-1];
for(int i=n-1;i>=0;i--)sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1;
x[sa[0]]=0;
for(int i=1;i<n;i++)
{
x[sa[i]]=cmp(y,sa[i],sa[i-1],k)?p-1:p++;
}
if(p>=n)break;
}
}
void getheight(int n)
{
int k=0;
for(int i=0;i<n;i++)rk[sa[i]]=i;
for(int i=0;i<n;i++)
{
if(k)k--;
int j=sa[rk[i]-1];
while(s[i+k]==s[j+k])
{
k++;
}
height[rk[i]]=k;
}
}
void st_build(int n)
{
for(int i = 0; i < n; i++)
{
dp[i][0] = height[i];
}
for(int j = 1; (1<<j) <= n; j++)
for(int i = 0; (i+(1<<j)-1) <= n; i++)
dp[i][j] = min(dp[i][j-1], dp[i+(1<<(j-1))][j-1]);
}
int query(int i, int j)
{
int l = min(rk[i], rk[j]);
int r = max(rk[i], rk[j]);
//if(l==r)return 1e9;
++l;
int cnt = log2(r-l+1), len = 1<<cnt;
return min(dp[l][cnt], dp[r-len+1][cnt]);
}
char S[maxn],S0[maxn];
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
n=0;
scanf("%s",S);
scanf("%s",S0);
int lens=strlen(S);
int lens0=strlen(S0);
for(int i=0;i<lens;i++)
{
s[n++]=S[i]-'A'+1;
}
s[n++]=29;
for(int i=0;i<lens0;i++)
{
s[n++]=S0[i]-'A'+1;
}
s[n]=0;
build_sa(n+1,30);
getheight(n+1);
st_build(n+1);
int ans=0;
for(int i=0;i<=lens-lens0;i++)
{
int t=0;
int j=0;
for(j=0;j<=lens0-1&&t<=3;)
{
if(s[i+j]!=s[lens+1+j])
{
t++;
j++;
}
else j+=query(i+j,lens+1+j);
}
if(t<=3)ans++;
}
printf("%d\n",ans);
}
return 0;
}