KMP / 字符串hash
前置
字符串hash
- 记录一字符串所有前缀出现次数
#define base 131
#define ull unsigned long long
unordered_map<ull,int> mp;
string a;
void hash(char a[]){
int len=strlen(a);
ull t=0;
for(int i=0;i<len;i++){
t=t*base+a[i]-'a'+1;
mp[t]++;
}
}
- 记录一字符串所有后缀出现次数
#define base 131
#define ull unsigned long long
unordered_map<ull,int> mp;
string a;
void hash(char a[]){
int len=strlen(a);
ull t=0,p=1;
for(int i=len-1;i>=0;i--){
t=t+p*(a[i]-'a'+1);
p*=base;
mp[t]++;
}
}
KMP
int ne[maxn];
void getne(char a[]){
ne[0]=0;
int len=strlen(a);
for(int i=1,j=0;i<len;i++){
while(a[j]!=a[i]&&j) j=ne[j-1];
if(a[i]==a[j]) j++;
ne[i]=j;
}
}
题意
给定n个字符串
s
1
,
s
2
.
.
.
s
n
s_1,s_2...s_n
s1,s2...sn,求
∑
i
=
1
n
∑
j
=
1
n
f
(
s
i
,
s
j
)
2
(
m
o
d
998244353
)
\sum\limits_{i=1}^n\sum\limits_{j=1}^nf(s_i,s_j)^2(mod998244353)
i=1∑nj=1∑nf(si,sj)2(mod998244353)
f
(
s
i
,
s
j
)
f(s_i,s_j)
f(si,sj)表示
s
i
s_i
si的前缀与
s
j
s_j
sj的后缀相同时的最长长度
分析
- 考虑朴素算法
枚举每个串的前缀,用map找到记录的相应后缀的数量,则 a n s + = 前 缀 长 度 2 ∗ 数 量 ans+=前缀长度^2*数量 ans+=前缀长度2∗数量 - 去重
显然用前缀去找后缀时,同一个字符串可能会被多次找到,使得答案偏大,因此考虑去重 - 性质
令 s i = x 1 x 2 x 3 , s j = y 1 y 2 y 3 s_i=x_1x_2x_3,s_j=y_1y_2y_3 si=x1x2x3,sj=y1y2y3(抽象为前缀、中间、后缀)
设 x 1 = y 3 x_1=y_3 x1=y3,即此时 s i s_i si第一次找到 s j s_j sj
设 x 1 x 2 x 3 = y 1 y 2 y 3 x_1x_2x_3=y_1y_2y_3 x1x2x3=y1y2y3,此时第二次找到
联立式子得 x 1 = y 1 = x 3 = y 3 , x 2 = y 2 x_1=y_1=x_3=y_3,x_2=y_2 x1=y1=x3=y3,x2=y2
则可知 s i = x 1 x 2 x 1 , s j = y 1 y 2 y 1 s_i=x_1x_2x_1,s_j=y_1y_2y_1 si=x1x2x1,sj=y1y2y1
因此可得出结论:若 s i s_i si通过 x 1 x 2 x 1 x_1x_2x_1 x1x2x1的形式找到 s j s_j sj, s i s_i si必能通过 x 1 x_1 x1找到 s j s_j sj - 转化
令k表示 x 1 x_1 x1找到的数量
则据分析,正解应为 a n s + = 前 缀 长 度 2 ∗ ( k − 数 量 ) ans+=前缀长度^2*(k-数量) ans+=前缀长度2∗(k−数量)
即本次枚举的最长相同前后缀对应的前缀能找到的数量减去本次枚举的数量
考虑KMP,求next数组
Code
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
const int mod=998244353;
const long long inf=1e18;
const int base=131;
const double pi=3.1415926;
#define ll long long
#define int long long
#define ull unsigned long long
#define maxx(a,b) (a>b?a:b)
#define minx(a,b) (a<b?a:b)
#define IOS ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define debug(...) fprintf(stderr, __VA_ARGS__)
inline ll qpow(ll base, ll n) { assert(n >= 0); ll res = 1; while (n) { if (n & 1) res = res * base % mod; base = base * base % mod; n >>= 1; } return res; }
ll gcd(ll a,ll b) {return b==0?a:gcd(b,a%b);}
ll lcm(ll a,ll b) { return a*b/gcd(a,b); }
ll inv(ll a) {return a == 1 ? 1 : (ll)(mod - mod / a) * inv(mod % a) % mod;}
ll C(ll n,ll m){if (m>n) return 0;ll ans = 1;for (int i = 1; i <= m; ++i) ans=ans*inv(i)%mod*(n-i+1)%mod;return ans%mod;}
ll A(ll n,ll m){ll sum=1; for(int i=n;i>=n-m+1;i--) sum=(sum*i)%mod; return sum%mod;}
ll GetSum(ll L, ll R) {return (R - L + 1ll) * (L + R) / 2ll;} //等差数列求和
/************/
int n,ans,cnt[maxn],ne[maxn];
string a[maxn];
unordered_map<ull,int> mp;
void getne(string a){
int len=a.size();
for(int i=1,j=0;i<len;i++){
while(a[i]!=a[j]&&j) j=ne[j-1];
if(a[i]==a[j]) j++;
ne[i]=j;
}
}
//统计后缀
void hash_suffix(string a){
int len=a.size();
ull t=0,p=1;
for(int i=len-1;i>=0;i--){
t=t+p*(a[i]-'a'+1);
p*=base;
mp[t]++;
}
}
void sol(){
for(int i=0;i<n;i++){
int len=a[i].size();
ull t=0;
getne(a[i]);
for(int j=0;j<len;j++){
t=t*base+a[i][j]-'a'+1;
cnt[j]=mp[t];
}
//去重
for(int j=0;j<len;j++){
int k=ne[j];
if(k) cnt[k-1]-=cnt[j];
}
for(int j=0;j<len;j++){
ans=(ans+cnt[j]%mod*(j+1)%mod*(j+1))%mod;
}
}
}
signed main()
{
IOS;
cin>>n;
for(int i=0;i<n;i++){
cin>>a[i];
hash_suffix(a[i]);
}
sol();
cout<<ans;
return 0;
}