题目
题解思路
子字符串只有40个,如果我们枚举每次枚举20个,再用枚举后的结果进行匹配,时间复杂度显然是够的,用STL内置的哈希表哈希string再匹配,交了一发mle了。
只能用自己的写的哈希来优化了。利用字符串哈希,我们可以O(1)的获取某个连续子串的哈希值,在用A中的小字符串来拼接大串的部分,通过前一半和后一半的组合来得出答案。
这里用dfs显然更优,我们判断可行性后再进入分支,如果用2进制枚举的化需要知道每个小字符串组合的哈希值,这东西也很大的。2^40
我们用dfs判断进入了某段长度后,再试探的进入下一个长度,即又之后的小字符串能否探到大字符串更大的长度。
这里需要注意好多边界,得考虑好。
碰到字符串哈希问题,尽量使用拼接的思想,因为我们以及可以O1获取连续子串值了。
AC代码
#include <bits/stdc++.h>
//#include <unordered_map>
//priority_queue
#define PII pair<int,int>
#define ll long long
using namespace std;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7 ;
const int mi = 2333 ;
const int N = 5000100 ;
long long cnt1[N] ;
long long cnt2[N] ;
long long len[50] ;
long long a[50] ;
long long num[N] ;
int n , m ;
string s ;
void calc( long long p[] , string sk )
{
int sz = sk.size() ;
p[0] = sk[0] - 'a' ;
for (int i = 1 ; i < sz ; i++ )
{
p[i] = ( p[i-1]*mi + sk[i] - 'a' )%mod ;
}
}
long long calc2(string sk )
{
long long ti = sk[0] - 'a' ;
for (int i = 1 ; i < sk.size() ; i++ )
{
ti = ( ti*mi + sk[i] - 'a' )%mod ;
}
return ti ;
}
long long qsm(long long di , long long m )
{
long long res = 1 ;
while (m)
{
if (m&1)
res = res * di %mod ;
m >>= 1 ;
di = di*di%mod ;
}
return res ;
}
long long calchash(long long a[] , int l , int r )
{
if ( l == 0 )
return a[r] ;
long long tmp ;
tmp = ( ( a[r] - a[l-1]*qsm(mi,r-l+1)%mod )%mod + mod )%mod ;
return tmp ;
}
void dfs1(int x , int l )
{
if ( l > 0 )
cnt1[l-1]++;
if ( x > m )
return ;
//cout << x << " " << l << "\n" ;
for (int i = x ; i <= m ; i++ )
{
if (a[i] == calchash(num,l,l+len[i]-1) )
dfs1(i+1,l+len[i]) ;
}
}
void dfs2(int x , int r )
{
cnt2[r+1]++;
if ( x <= m )
return ;
//cout << x << " " << r << "\n" ;
for (int i = x ; i > m ; i-- )
{
if (a[i] == calchash(num,r-len[i]+1,r))
dfs2(i-1,r-len[i]) ;
}
}
int main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin >> n >> s ;
calc(num,s) ;
for (int i = 1 ; i <= n ; i++ )
{
string tp ;
cin >> tp ;
len[i] = tp.size() ;
a[i] = calc2(tp) ;
}
m = n/2 ;
dfs1(1,0) ;
int lz = s.size() ;
dfs2(n,lz-1) ;
long long ans = 0 ;
for (int i = 1 ; i < lz -1 ; i++ )
{
//cout << cnt1[i] << " " << cnt2[i+1] << "\n";
ans += cnt1[i]*cnt2[i+1] ;
}
ans += cnt2[0] + cnt1[lz-1] ;
cout << ans << "\n";
return 0 ;
}