题目链接
All with Pairs
题意:n个字符串,对每个字符串i 求所有字符串j 的f(si,sj)*f(si,sj) 的和。 f(si,sj)代表 si的最长前缀与 sj的后缀相同。
比如 f( ab aab ) =2 , f(ba,aab) =1
做法:很明显的对所有的后缀hash 保存,然后 枚举 对 每个字符串i的前缀hash ans+=mp[hash]*len*len len 是当前前缀hash 的长度。
当然这样计算是会有重复计算的。
例如 aba 和 aba 会有len=3 len =1 两种情况,而答案只需要保存 len=3 。那么怎么去掉重复的计算呢?
当前找到一个后缀 aba 与 map里面保存的 aba 计算答案的时候 ,就一定会出现 以当前位置 j 为最长后缀(map里的字符串),匹配到一个与 当前字符串i 的 最长前缀。这句话越看越像kmp里的next数组性质。
于是 当样例 “aba” “aba“匹配到len=3 的时候 然后将 位置 next[3] 的字符 标志vis[next[3]]++;当枚举到 前缀 next[3] 位置的时候 将map值减去 vis[next[3]]即可。普通的map超时了,然后用了队友的秘制 hash_map 过了。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=(b);++i)
#define mem(a,x) memset(a,x,sizeof(a))
#define pb push_back
using namespace std;
typedef long long ll;
ll gcd(ll a,ll b) { return b?gcd(b,a%b):a;}
const int N=1e6+1,M=1e6+1;
ll base[2]={43,47};
ll f[2][N],mod[2]={1000000007,998244353},h[3][N];
void init()
{
for(int j=0;j<=1;++j) f[j][0]=1;
for(int i=1;i<N;++i)
for(int j=0;j<=1;++j) f[j][i]=f[j][i-1]*base[j]%mod[j];
}
ll getv(int l,int r,int j)
{
return (h[j][r]-h[j][l-1]*f[j][r-l+1]%mod[j]+mod[j])%mod[j];
}
const ll modd=998244353;
const ll mood=1e9+7;
const int maxsz=3e6+7;
template<typename key,typename val>
class hash_map{public:
struct node{key u;val v;int next;};
vector<node> e;
int head[maxsz],nume,numk,id[maxsz];
int geths(pair<ll,ll> &u){
int x=(1ll*u.first*mood+u.second)%maxsz;
if(x<0) return x+maxsz;
return x;
}
val& operator[](key u){
int hs=geths(u);
for(int i=head[hs];i;i=e[i].next)if(e[i].u==u) return e[i].v;
if(!head[hs])id[++numk]=hs;
if(++nume>=e.size())e.resize(nume<<1);
return e[nume]=(node){u,0,head[hs]},head[hs]=nume,e[nume].v;
}
void clear(){
for(int i=0;i<=numk;i++) head[id[i]]=0;
numk=nume=0;
}
};
hash_map<pair<ll,ll>,ll> mp;
// unordered_map<pair<ll,ll>,ll>mp;
//map<pair<ll,ll>,ll>mp;
string s[N];
ll vis[N];
int ne[N];
void get(string b) //常规处理方法
{
int len=b.size();
ne[0]=-1;
for(int i=0,j=-1;i<len;)
{
if(j==-1||b[i]==b[j]) ne[++i]=++j;
else j=ne[j];
}
// for(int i=0;i<=len;++i){
// printf("%d ",ne[i]);
// }
// puts("");
}
int n;
int main()
{
std::ios::sync_with_stdio(false);
//get("aaa");
init();
cin>>n;
rep(i,1,n) cin>>s[i];
rep(i,1,n)
{
int len=s[i].size();
for(int j=0;j<len;++j){
int x=s[i][j]-'a'+1;
for(int k=0;k<=1;++k){
h[k][j+1]=(h[k][j]*base[k]%mod[k]+x)%mod[k];
}
}
pair<ll,ll>tmp;
for(int j=1;j<=len;++j){
tmp.first=getv(j,len,0);
tmp.second=getv(j,len,1);
//printf("l:%d r:%d tmp:%lld %lld\n",j,len,tmp.first,tmp.second);
mp[tmp]++;
}
// puts("");
// puts("");
// puts("");
}
//puts("");
//puts("");
//puts("");
ll ans=0,pre;
rep(i,1,n)
{
int len=s[i].size();
for(int j=0;j<len;++j){
int x=s[i][j]-'a'+1;
for(int k=0;k<=1;++k){
h[k][j+1]=(h[k][j]*base[k]%mod[k]+x)%mod[k];
}
}
//puts("");
get(s[i]);
for(int i=0;i<=len;++i) vis[i]=0;
pair<ll,ll>tmp;
pre=0;
for(int j=len;j>=1;--j){
tmp.first=getv(1,j,0);
tmp.second=getv(1,j,1);
// printf("l:%d r:%d tmp:%lld %lld\n",1,j,tmp.first,tmp.second);
// printf("mp:%lld pre:%lld\n\n",mp[tmp],pre);
ans=(ans+(mp[tmp]-vis[j]+modd)%modd*j%modd*j%modd)%modd;
int nx=ne[j];
vis[nx]=(vis[nx]+mp[tmp])%modd;
pre=mp[tmp];
}
}
cout<<ans<<endl;
//printf("%lld\n",ans);
}
/*
3
abc
abc
abc
*/