All with Pairs
题意
记
f
(
s
,
t
)
f(s,t)
f(s,t)为最大的
i
i
i使得
s
1...
i
=
t
∣
t
∣
−
i
+
1...
∣
t
∣
s_{1...i} =t_{\left|t\right|-i+1...\left|t\right|}
s1...i=t∣t∣−i+1...∣t∣
给
n
n
n个串
s
1
,
s
2
,
.
.
.
,
s
n
s_1,s_2,...,s_n
s1,s2,...,sn,求
∑
i
=
1
n
∑
j
=
1
n
f
(
s
i
,
s
j
)
2
\displaystyle\sum_{i = 1} ^ n\displaystyle\sum_{j = 1} ^ n f(s_i,s_j)^2
i=1∑nj=1∑nf(si,sj)2
题解
统计所有串每一个后缀出现次数,这个可以用哈希来实现
map<ull, int> mp;
void insert(string &s) {
ull hash = 0, b = 1;//unsigned long long就可以自然溢出
for (int i = s.length() - 1; i >= 0; i--, b *= base) {
hash += b * (s[i] - 'a' + 1);
mp[hash]++;
}
}
对于一个串
s
s
s来说,记
c
n
t
[
i
]
cnt[i]
cnt[i]为所有串中后缀等于
s
1...
i
s_{1...i}
s1...i的数量
那么串
s
s
s的贡献就是
∑
i
=
1
∣
s
∣
i
2
c
n
t
[
i
]
\displaystyle\sum_{i=1}^{|s|}i^2cnt[i]
i=1∑∣s∣i2cnt[i]
这里有一个问题,如果我们按上面
H
a
s
h
Hash
Hash的方法来求得所有串每一个后缀出现次数
会有重复计算,如
a
b
a
aba
aba的后缀有
a
,
b
a
,
a
b
a
a,ba,aba
a,ba,aba, 如果我们当前串
s
s
s能够匹配的最大长度为
3
3
3,即匹配的串为
a
b
a
aba
aba时,
a
a
a这个串也一定是匹配的,所以要对
c
n
t
cnt
cnt进行容斥
容斥只需要从前往后令
c
n
t
[
n
e
x
t
[
i
]
]
=
c
n
t
[
n
e
x
t
[
i
]
]
−
c
n
t
[
i
]
cnt[next[i]] = cnt[next[i]]-cnt[i]
cnt[next[i]]=cnt[next[i]]−cnt[i]即可
因为如果从后往前容斥就要不断跳 n e x t next next数组,前面每一个重复的串都要减去 c n t [ i ] cnt[i] cnt[i]的贡献
但是正向进行就只用减一次,因为 c n t [ n e x t [ i ] ] = c n t [ n e x t [ i ] ] − c n t [ i ] cnt[next[i]] = cnt[next[i]]-cnt[i] cnt[next[i]]=cnt[next[i]]−cnt[i]中, c n t [ i ] cnt[i] cnt[i]本就包含了后面重复的贡献
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAX = 1e5 + 10;
const int base = 233;
const int mod = 998244353;
vector<int> getNext(string &s) {
int n = s.length();
vector<int> nxt(n);
for (int i = 1; i < n; i++) {
int j = nxt[i - 1];
while (j > 0 && s[i] != s[j]) j = nxt[j - 1];
if (s[i] == s[j]) j++;
nxt[i] = j;
}
return nxt;
}
map<ull, int> mp;
void insert(string &s) {
ull hash = 0, b = 1;//unsigned long long就可以自然溢出
for (int i = s.length() - 1; i >= 0; i--, b *= base) {
hash += b * (s[i] - 'a' + 1);
mp[hash]++;
}
}
int N;
string s[MAX];
int cnt[MAX * 10];
int main() {
cin >> N;
for (int i = 1; i <= N; i++) {
cin >> s[i];
insert(s[i]);
}
ll ans = 0;
for (int i = 1; i <= N; i++) {
vector<int> nxt = getNext(s[i]);
ull hash = 0;
for (int j = 0; j < s[i].length(); j++) {
hash = hash * base + s[i][j] - 'a' + 1;
cnt[nxt[j]] -= (cnt[j + 1] = mp[hash]);
//我这里j是从0开始的, 但是cnt数组是从1开始的,有点不一样
}
for (int j = 1; j <= s[i].length(); j++)
ans = (ans + 1ll * cnt[j] * j % mod * j % mod) % mod;
}
printf("%lld\n", ans);
return 0;
}