题目:
http://acm.hdu.edu.cn/showproblem.php?pid=6153
题意:
给出两个字符串 s1,s2 ,求出 s2 的每一个后缀在 s1 中出现的次数乘以这个后缀的长度,并累加求和,输出这个和
思路:
kmp
和
extkmp
都可以做。
extkmp
:求得是
s2
后缀在
s1
中出现的次数,我们将两个字符串翻转,就可以求
s2
前缀在
s1
中出现的次数,等价于没翻转之前求后缀,求前缀可以用
extkmp
算法,可以发现,长度较小的前缀一定包含在长度较大前缀里面,所以求出
extend
数组,记录每个公共前缀长度的出现次数,然后从长度较大的前缀开始循环,把数量一直累加到长度较小的前缀上
kmp:
首先也是翻转两个字符串,然后匹配两个串,当
s2
串位置
j
和
extkmp:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1000000 + 10, mod = 1e9 + 7;
char ori[N], pat[N];
int Next[N], extend[N];
int num[N];
void get_next(char *pat)
{
int len = strlen(pat);
Next[0] = len;
int k = 0;
while(k + 1 < len && pat[k] == pat[k+1]) ++k;
Next[1] = k;
k = 1;
for(int i = 2; pat[i]; i++)
{
if(i + Next[i-k] < k + Next[k]) Next[i] = Next[i-k];
else
{
int j = k + Next[k] - i;
if(j < 0) j = 0;
while(i + j < len && pat[j] == pat[i+j]) ++j;
Next[i] = j;
k = i;
}
}
}
void extkmp(char *ori, char *pat)
{
get_next(pat);
int leno = strlen(ori), lenp = strlen(pat);
int k = 0;
while(k < leno && k < lenp && ori[k] == pat[k]) ++k;
extend[0] = k;
k = 0;
for(int i = 1; ori[i]; i++)
{
if(i + Next[i-k] < k + extend[k]) extend[i] = Next[i-k];
else
{
int j = k + extend[k] - i;
if(j < 0) j = 0;
while(i + j < leno && j < lenp && ori[i+j] == pat[j]) ++j;
extend[i] = j;
k = i;
}
}
}
int main()
{
int t;
scanf("%d", &t);
while(t--)
{
scanf("%s%s", ori, pat);
int leno = strlen(ori), lenp = strlen(pat);
reverse(ori, ori + leno);
reverse(pat, pat + lenp);
extkmp(ori, pat);
memset(num, 0, sizeof num);
for(int i = 0; i < leno; i++) num[extend[i]]++;
ll ans = 0;
for(int i = lenp; i >= 1; i--)
{
num[i] += num[i+1];
ans = (ans + 1LL * num[i] * i) % mod;
}
printf("%lld\n", ans);
}
return 0;
}
kmp:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1000000 + 10, mod = 1e9 + 7;
char ori[N], pat[N];
int Next[N];
int num[N];
void get_next(char *pat)
{
int i = 0, j = -1;
Next[0] = -1;
while(pat[i])
{
if(j == -1 || pat[i] == pat[j]) Next[++i] = ++j;
else j = Next[j];
}
}
void kmp(char *ori, char *pat)
{
get_next(pat);
int i = 0, j = 0;
while(ori[i])
{
if(j == -1 || ori[i] == pat[j]) ++i, ++j;
else j = Next[j];
if(j != -1) num[j]++;
if(j != -1 && ! pat[j]) j = Next[j];
}
}
int main()
{
int t;
scanf("%d", &t);
while(t--)
{
memset(num, 0, sizeof num);
scanf("%s%s", ori, pat);
int leno = strlen(ori), lenp = strlen(pat);
reverse(ori, ori + leno);
reverse(pat, pat + lenp);
kmp(ori, pat);
ll ans = 0;
for(int i = lenp; i >= 1; i--)
{
num[Next[i]] += num[i];
ans = (ans + 1LL * num[i] * i) % mod;
}
printf("%lld\n", ans);
}
return 0;
}