洛谷传送门
BZOJ传送门
题目描述
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
输入输出格式
输入格式:
两行,两个字符串 s 1 s_1 s1, s 2 s_2 s2,长度分别为 n 1 n_1 n1, n 2 n_2 n2。 1 ≤ n 1 , n 2 ≤ 200000 1 \le n_1, n_2\le 200000 1≤n1,n2≤200000,字符串中只有小写字母
输出格式:
输出一个整数表示答案
输入输出样例
输入样例#1:
aabb
bbaa
输出样例#1:
10
解题分析
这道题大概有两种做法。
第一种是在后缀自动机中插入第一个串后插入一个无法匹配的字符, 再插入第二个串, 并同时记录两次插入每个状态的 r i g h t right right集合大小。 插入无法匹配的字符的原因是不能让中间连起来形成合法的子串。
代码如下:
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <cstdio>
#define R register
#define IN inline
#define W while
#define MX 2000500
char dat[MX];
int to[MX][27], par[MX], len[MX], siz[2][MX], buc[MX], ind[MX];
int l, cnt, cur, last;
long long ans;
namespace SAM
{
IN void insert(R int id, R int typ)
{
R int now = last, tar;
cur = ++cnt; len[cur] = len[last] + 1; siz[typ][cur] = 1;
for (; (~now) && !to[now][id]; now = par[now]) to[now][id] = cur;
if(now < 0) return par[last = cur] = 0, void();
tar = to[now][id];
if(len[tar] == len[now] + 1) return par[last = cur] = tar, void();
int nw = ++cnt; len[nw] = len[now] + 1;
par[nw] = par[tar], par[tar] = par[last = cur] = nw;
std::memcpy(to[nw], to[tar], sizeof(to[nw]));
for (; (~now) && to[now][id] == tar; now = par[now]) to[now][id] = nw;
}
IN void calc()
{
for (R int i = 1; i <= cnt; ++i) buc[len[i]]++;
for (R int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
for (R int i = 1; i <= cnt; ++i) ind[buc[len[i]]--] = i;
for (R int i = cnt; i; --i)
if(~par[i]) ans += 1ll * (len[ind[i]] - len[par[ind[i]]]) * siz[0][ind[i]] * siz[1][ind[i]],
siz[0][par[ind[i]]] += siz[0][ind[i]], siz[1][par[ind[i]]] += siz[1][ind[i]];
}
}
int main(void)
{
par[0] = -1;
scanf("%s", dat + 1); l = std::strlen(dat + 1);
for (R int i = 1; i <= l; ++i) SAM::insert(dat[i] - 'a', 0);
SAM::insert(26, 0);
scanf("%s", dat + 1); l = std::strlen(dat + 1);
for (R int i = 1; i <= l; ++i) SAM::insert(dat[i] - 'a', 1);
SAM::calc(); printf("%lld\n", ans);
}
另一种是用广义后缀自动机。 我们在后缀自动机中插入第一个串后将 l a s t last last设为初始值, 再插入第二个串。 因为现在每个点的意义是两个串的公共子串, 所以统计时注意 l e n len len要满足第二个串的要求, 及时分裂节点。
代码如下:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <cstdlib>
#define R register
#define IN inline
#define W while
#define MX 1000500
int par[MX], to[MX][26], len[MX], siz[2][MX], buc[MX], ind[MX];
int cnt, l, last, cur;
char dat[MX];
long long ans;
namespace SAM
{
IN void insert(R int ln, R int id, R int typ)
{
R int now = last, tar, sig = 0;
if((!to[now][id]) || (len[to[now][id]] != len[now] + 1)) cur = ++cnt;
else cur = to[now][id], sig = 1;
++siz[typ][cur], last = cur;
if(sig) return;
len[cur] = ln;
for (; (~now) && !to[now][id]; now = par[now]) to[now][id] = cur;
if (now < 0) return par[cur] = 0, void();
tar = to[now][id];
if(len[tar] == len[now] + 1) return par[cur] = tar, void();
int nw = ++cnt; len[nw] = len[now] + 1;
par[nw] = par[tar], par[tar] = par[cur] = nw;
std::memcpy(to[nw], to[tar], sizeof(to[nw]));
for (; (~now) && to[now][id] == tar; now = par[now]) to[now][id] = nw;
}
}
int main(void)
{
par[0] = -1;
scanf("%s", dat + 1); l = std::strlen(dat + 1);
for (R int i = 1; i <= l; ++i) SAM::insert(i, dat[i] - 'a', 0);
scanf("%s", dat + 1); l = std::strlen(dat + 1); last = 0;
for (R int i = 1; i <= l; ++i) SAM::insert(i, dat[i] - 'a', 1);
for (R int i = 1; i <= cnt; ++i) buc[len[i]]++;
for (R int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
for (R int i = 1; i <= cnt; ++i) ind[buc[len[i]]--] = i;
for (R int i = cnt; i; --i)
ans += 1ll * siz[0][ind[i]] * siz[1][ind[i]] * (len[ind[i]] - len[par[ind[i]]]),
siz[1][par[ind[i]]] += siz[1][ind[i]], siz[0][par[ind[i]]] += siz[0][ind[i]];
printf("%lld", ans);
}