题目链接:点击进入
题目:
思路1
通过题目可以知道,满足条件的前缀对可以由 两个字符串的公共前缀 与 字符串 t 的一个后缀与模式串 s 的公共前缀 组成 ,也就可以转换成求字符串 t 的extend数组的问题,最终的前缀对总和,可以通过枚举 s和 t 的所有相同前缀 si 然后将 ti 对应的 extend 数组计入答案即可。
代码1
#include<iostream>
#include<string>
#include<map>
#include<set>
//#include<unordered_map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#include<iomanip>
#include<cmath>
#include<fstream>
#define X first
#define Y second
#define base 131
#define INF 0x3f3f3f3f3f3f3f3f
#define pii pair<int,int>
#define lowbit(x) x & -x
#define inf 0x3f3f3f3f
#define int long long
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const double eps=1e-7;
const double pai=acos(-1.0);
const int N=2e4+10;
const int maxn=1e6+10;
const int mod=1e9+7;
int n,m,k,cnt,tot,l1,l2,ans;
int Next[maxn],ext[maxn];
string s,t;
//扩展KMP是用来求主串每个点可以向后延伸与模式串匹配的最长的长度
void getnext()
{
int f=0,l=0;
Next[0]=l1;
for(int i=1;i<l1;i++)
{
Next[i]=min(l-i+1,Next[i-f]);
if(Next[i]<0) Next[i]=0;
while(i+Next[i]<l1&&s[Next[i]]==s[i+Next[i]])
Next[i]++;
if(i+Next[i]-1>l)
{
l=i+Next[i]-1;
f=i;
}
}
}
void exkmp()
{
getnext();
int f=0,l=-1,minn=min(l1,l2);
while(l<minn-1&&t[l+1]==s[l+1])
l++;
ext[0]=l+1;
for(int i=1;i<l2;i++)
{
ext[i]=min(l-i+1,Next[i-f]);
if(ext[i]<0) ext[i]=0;
while(ext[i]<l1&&i+ext[i]<l2&&s[ext[i]]==t[i+ext[i]])
ext[i]++;
if(i+ext[i]-1>l)
{
l=i+ext[i]-1;
f=i;
}
}
}
signed main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);cout.tie(0);
cin>>s;cin>>t;
l1=s.size(),l2=t.size();
exkmp();
for(int i=1;i<=l2;i++)
{
if(s[i-1]!=t[i-1]) break;
if(t[i]!=s[0]) continue;
ans+=ext[i];
}
cout<<ans<<endl;
return 0;
}
思路2
枚举 s和 t 的所有相同前缀 si ,然后二分求 t-si 中能匹配上 s 的前缀sj 的 最大长度,结果累加就是答案
( 哈希能够O(1)判断两个子串是否匹配,又因为某个前缀匹配时,前缀的前缀也一定匹配,所以可以二分找到最长前缀 )
代码2
#include<iostream>
#include<string>
#include<map>
#include<set>
//#include<unordered_map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#include<iomanip>
#include<cmath>
#include<fstream>
#define X first
#define Y second
#define base 131
#define INF 0x3f3f3f3f3f3f3f3f
#define pii pair<int,int>
#define lowbit(x) x & -x
#define inf 0x3f3f3f3f
//#define int long long
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const double eps=1e-7;
const double pai=acos(-1.0);
const int N=2e4+10;
const int maxn=1e6+10;
const int mod=1e9+7;
int n,m,k,cnt,tot,l1,l2;
ull h1[maxn],h2[maxn],p[maxn],ans;
string s,t;
void hash_init()
{
p[0]=1;h1[0]=0;h2[0]=0;
for(int i=1;i<=max(l1,l2);i++)
p[i]=p[i-1]*base;
for(int i=1;i<=l1;i++)
h1[i]=h1[i-1]*base+(s[i-1]-'0');
for(int i=1;i<=l2;i++)
h2[i]=h2[i-1]*base+(t[i-1]-'0');
}
ull hash_code(int l, int r,int op)
{
if(op==1)
return h1[r]-h1[l-1]*p[r-l+1];
else
return h2[r]-h2[l-1]*p[r-l+1];
}
int main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);cout.tie(0);
cin>>s;cin>>t;
l1=s.size(),l2=t.size();
hash_init();
for(int i=1;i<=l2;i++)
{
if(s[i-1]!=t[i-1]) break;
if(t[i]!=s[0]) continue;
int l=1,r=l1,res=0;
while(l<=r)
{
int mid=l+r>>1;
ull t1=hash_code(1,mid,1);
ull t2=hash_code(i+1,i+mid,2);
if(t1==t2)
l=mid+1,res=mid;
else
r=mid-1;
}
ans+=res;
}
cout<<ans<<endl;
return 0;
}