题意
传送门 Codeforces 452E. Three strings
题解
将子串看做后缀的前缀,那么使用两个互不相同且非小写字母的字符将 3 3 3 个字符串拼接起来,求解后缀数组以及高度数组。由于后缀数组的有序性,故最长相同前缀不小于某个长度 h h h 的后缀集合在后缀数组中处于连续的位置,且对应位置的高度数组的值不小于 h h h。
那么可以使用并查集维护这样连续的位置,从高度数组的最大值( 3 3 3 个字符串的最长长度)开始处理,使用相邻 s a [ i ] , s a [ i + 1 ] sa[i],sa[i+1] sa[i],sa[i+1] 的最长公共前缀进行合并操作,此时已存在于连通分量中的后缀与当前合并的后缀最长公共前缀的长度等于当前枚举的高度;同时统计满足条件的元组数量即可。
#include <bits/stdc++.h>
using namespace std;
#define maxn 300005
const int mod = 1000000007;
int n, k, sa[maxn], lcp[maxn], rnk[maxn], tmp[maxn];
string S, str[3];
bool cmp_sa(int i, int j)
{
if (rnk[i] != rnk[j])
return rnk[i] < rnk[j];
int ri = i + k <= n ? rnk[i + k] : -1;
int rj = j + k <= n ? rnk[j + k] : -1;
return ri < rj;
}
void construct_sa(string &S, int *sa)
{
for (int i = 0; i <= n; ++i)
{
sa[i] = i;
rnk[i] = i < n ? S[i] : -1;
}
for (k = 1; k < n; k <<= 1)
{
sort(sa, sa + n + 1, cmp_sa);
tmp[sa[0]] = 0;
for (int i = 1; i <= n; ++i)
tmp[sa[i]] = tmp[sa[i - 1]] + (cmp_sa(sa[i - 1], sa[i]) ? 1 : 0);
memcpy(rnk, tmp, sizeof(int) * (n + 1));
}
}
void construct_lcp(string &S, int *sa, int *lcp)
{
for (int i = 0; i <= n; ++i)
rnk[sa[i]] = i;
int h = 0;
for (int i = 0; i < n; ++i)
{
int j = sa[rnk[i] - 1];
if (h > 0)
--h;
for (; i + h < n && j + h < n; ++h)
if (S[i + h] != S[j + h])
break;
lcp[rnk[i] - 1] = h;
}
}
typedef long long ll;
typedef pair<int, int> P;
int cnt[maxn][3], len[3], par[maxn], rk[maxn];
ll res[maxn];
vector<P> hs[maxn];
void init_union_find(int n)
{
for (int i = 0; i <= n; ++i)
par[i] = i, rk[i] = 0;
}
int find(int x)
{
return par[x] == x ? x : (par[x] = find(par[x]));
}
ll unite(int x, int y)
{
x = find(x), y = find(y);
if (rk[x] > rk[y])
swap(x, y);
if (rk[x] == rk[y])
++rk[y];
par[x] = y;
ll r = 0;
r -= (ll)cnt[x][0] * cnt[x][1] * cnt[x][2];
r -= (ll)cnt[y][0] * cnt[y][1] * cnt[y][2];
for (int i = 0; i < 3; ++i)
cnt[y][i] += cnt[x][i];
r += (ll)cnt[y][0] * cnt[y][1] * cnt[y][2];
return r;
}
int main()
{
n = 0;
for (int i = 0; i < 3; ++i)
{
cin >> str[i];
len[i] = str[i].length();
for (int j = 0; j < len[i]; ++j)
cnt[n++][i] = 1;
if (i < 2)
++n;
}
S = str[0] + '$' + str[1] + '#' + str[2];
construct_sa(S, sa);
construct_lcp(S, sa, lcp);
for (int i = 0; i < n; ++i)
hs[lcp[i]].push_back(P(sa[i], sa[i + 1]));
int minh = min(len[0], min(len[1], len[2])), maxh = max(len[0], max(len[1], len[2]));
init_union_find(n);
ll sum = 0;
for (int i = maxh; i >= 1; --i)
{
for (int j = 0; j < (int)hs[i].size(); ++j)
{
sum = (sum + unite(hs[i][j].first, hs[i][j].second)) % mod;
}
if (i <= minh)
res[i] = sum;
}
for (int i = 1; i <= minh; ++i)
printf("%lld%c", res[i], i == minh ? '\n' : ' ');
return 0;
}