题目链接:点击这里
题目大意:
给出一个长度为
n
n
n 的字符串
S
S
S 和一个长度为
m
m
m 的字符串
T
T
T 。我们定义
k
k
k 匹配为两个长度相同的字符串至多有
k
k
k 个位置不同,求当
k
∈
[
0
,
m
]
k \in[0,m]
k∈[0,m] 时,
T
T
T 可以和
S
S
S 中的多少个子串匹配
题目分析:
类似于 P4173 残缺的字符串 的思路。
我们发现字符集很小,所以考虑枚举字符集的元素。
如果
S
i
=
T
j
S_i=T_j
Si=Tj ,那么当前字符在
S
S
S 中以
i
+
m
−
j
i+m-j
i+m−j 结尾的长度为
m
m
m 的子串中就产生
1
1
1 的贡献,我们发现
i
+
m
−
j
=
k
i+m-j=k
i+m−j=k ,将
T
T
T 串反转后就是
i
+
j
=
k
i+j=k
i+j=k ,就凑成了卷积的形式,我们卷一下就可以得到一个数组
c
c
c ,
c
[
i
]
c[i]
c[i] 就表示以
T
T
T 在
S
S
S 中匹配的多少位。
而
m
−
c
[
i
]
m-c[i]
m−c[i] 就是还差多少位
T
T
T 就可以与
S
S
S 的子串完全匹配了,因为我们有通配符,所以需要考虑通配符的作用,这时只需要容斥一下:总的通配符个数=
S
S
S 子串中通配符的个数 +
T
T
T 串中通配符的个数 -
S
S
S 子串和
T
T
T 串中通配符匹配的个数。第一项是一个前缀和可以预处理,第二项是一个定值也可以预处理,第三项是一个卷积,卷一下也可以处理出贡献。
我们可以将容斥过程直接操作于
c
c
c 数组上,然后
m
−
c
[
i
]
m-c[i]
m−c[i] 就是
S
[
i
−
m
+
1
,
i
]
S_{[i-m+1,i]}
S[i−m+1,i] 与
T
T
T 的距离,用桶存一下,最后求个前缀就是答案。
时间复杂度位
O
(
k
n
l
o
g
n
)
O(knlogn)
O(knlogn) ,其中
k
k
k 为字符集大小
具体细节见代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<set>
#include<map>
#include<stack>
#include<queue>
#define ll long long
#define inf 0x3f3f3f3f
//#define int ll
using namespace std;
int read()
{
int res = 0,flag = 1;
char ch = getchar();
while(ch<'0' || ch>'9')
{
if(ch == '-') flag = -1;
ch = getchar();
}
while(ch>='0' && ch<='9')
{
res = (res<<3)+(res<<1)+(ch^48);//res*10+ch-'0';
ch = getchar();
}
return res*flag;
}
const int maxn = 4e6+5;
//const int mod = 9973;
const double pi = acos(-1);
const double eps = 1e-8;
int n,m,len = 2,L,rev[maxn],sum[maxn],ans[maxn];
ll a[maxn],b[maxn],c[maxn];
char sa[maxn],sb[maxn];
const int mod = 998244353;//mod原根为3
const int g = 3;//原根
ll qpow(ll a,ll b)
{
ll res = 1;
while(b)
{
if(b&1) res = res*a%mod;
a = a*a%mod;
b >>= 1;
}
return res;
}
void get_rev()
{
for(int i = 0;i < len;i++)
rev[i] = (rev[i>>1]>>1)|(len>>1)*(i&1);
// for(int i = 0;i < len;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(L-1));
}
void ntt(ll *a,int dft)
{
for(int i = 0;i < len;i++)
if(i < rev[i]) swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
for(int mid = 1;mid < len;mid <<= 1)//mid是准备合并序列的长度的二分之一
{
ll W = qpow(g,(mod-1)/(mid<<1));
if(dft == -1) W = qpow(W,mod-2);
for(int i = 0;i < len;i += mid<<1)//mid*2是准备合并序列的长度,i是合并到了哪一位
{
ll w = 1;
for(int j = i;j < mid+i;j++,w = w*W%mod)//只扫左半部分,得到右半部分的答案
{
int x = a[j];
int y = w*a[j+mid]%mod;
a[j] = (x+y)%mod;
a[j+mid] = ((x-y)%mod+mod)%mod;
}
}
}
if(dft == -1)
{
int tmp = qpow(len,mod-2);
for(int i = 0;i < len;i++) a[i] = (ll)a[i]*tmp%mod;
}
}
signed main()
{
// freopen("1003.in","r",stdin);
// freopen("1.out","w",stdout);
int t = read();
while(t--)
{
n = read(),m = read();
len = 2,L = 1;
while(n+m>=len) len <<= 1,L++;
for(int i = 0;i <= len;i++) c[i] = sum[i] = ans[i] = sa[i] = sb[i] = 0;
get_rev();
scanf("%s%s",sa,sb);
reverse(sb,sb+m);
for(int k = 0;k <= 10;k++)
{
char ch = k==10 ? '*' : '0'+k;
for(int i = 0;i < len;i++)
{
a[i] = sa[i]==ch;
b[i] = sb[i]==ch;
}
ntt(a,1);ntt(b,1);
for(int i = 0;i < len;i++)
{
if(k < 10) c[i] = (c[i]+a[i]*b[i])%mod;
else c[i] = (c[i]-a[i]*b[i]%mod+mod)%mod;
}
}
ntt(c,-1);
int cnt = 0;
for(int i = 0;i < m;i++) cnt += sb[i]=='*';
for(int i = 0;i < n;i++)
{
if(i) sum[i] = sum[i-1];
sum[i] += sa[i]=='*';
}
for(int i = m-1;i < n;i++)
{
c[i] = (c[i]+sum[i]+cnt)%mod;
if(i >= m) c[i] = (c[i]-sum[i-m]+mod)%mod;
}
for(int i = m-1;i < n;i++) ans[m-c[i]]++;
for(int i = 0;i <= m;i++)
{
if(i) ans[i] += ans[i-1];
printf("%d\n",ans[i]);
}
}
return 0;
}