4566: [Haoi2016]找相同字符
Time Limit: 20 Sec  Memory Limit: 256 MB  Submit: 640  Solved: 350
[ Submit][ Status][ Discuss]
Description
给定两个字符串，求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
Input
两行，两个字符串 s1、s2，长度分别为 n1、n2。1 <= n1, n2 <= 200000，字符串中只有小写字母。
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
Sample Output
10
HINT
Source
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#include<queue>
#include<stack>
#include<utility>
#include<vector>
using namespace std;
typedef long long LL;

// Capacity of every per-state array (and of the input buffer).
const int maxn = 1200010;
// Not referenced anywhere in the visible code.
const int maxs = 200100;

// Input buffer; main reads each string at s + 1, so characters are 1-based.
char s[maxn];
// Length of the string currently held in s.
int n;

// Suffix-automaton storage, shared by both automata. State 0 is never
// allocated (tot is pre-incremented), so 0 doubles as "no state".
int fa[maxn];      // parent (suffix link) of each state; 0 for a root
int ch[maxn][30];  // transitions; only labels 1..26 are used, 0 = absent
int Max[maxn];     // length of the longest string of each state
int du[maxn];      // number of children each state has in the fa tree
int Q[maxn];       // 1-based work queue used by top()
int rt[3];         // rt[1], rt[2]: roots of the first / second automaton
int tot;           // index of the last allocated state
int last;          // state reached by the whole prefix inserted so far

// Not referenced anywhere in the visible code.
bool vis[maxn];

// Per-state count: 1 for every state created for a new prefix, 0 for clones,
// then summed up the fa tree by top().
LL cnt[maxn];
// Final answer, accumulated by dfs().
LL ans;
// Reads the next decimal integer from stdin.
//
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the result is negated. Digits are then consumed up to
// (and including) the first non-digit character.
//
// Returns 0 when end of input is reached before any digit. The character is
// kept in an int so that EOF can be told apart from real input; storing it in
// a char made the skip loop spin forever once the stream was exhausted.
inline LL getint()
{
    LL ret = 0,f = 1;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    return ret * f;
}
// Appends one character to automaton d (online suffix-automaton extension).
//
//   x  label of the character, expected in 1..26
//   d  which automaton is being built; rt[d] is its root
//
// Allocates a state for the new prefix (cnt = 1), links it into the fa tree,
// and keeps du[] equal to the number of fa-children of every state. A clone,
// when one is needed, keeps cnt = 0.
inline void insert(int x,int d)
{
    int walker = last;              // walks up the suffix links of the old `last`
    const int cur = ++tot;          // state for the newly extended prefix

    Max[cur] = Max[walker] + 1;
    cnt[cur] = 1;
    fa[cur] = rt[d];                // provisional parent; overwritten below if needed
    last = cur;

    // Add the missing x-transitions along the suffix-link chain.
    for (; walker && !ch[walker][x]; walker = fa[walker])
        ch[walker][x] = cur;

    // Ran off the chain: the root stays the parent and gains a child.
    if (!walker)
    {
        ++du[rt[d]];
        return;
    }

    const int target = ch[walker][x];

    // Existing state already has exactly the right length: attach under it.
    if (Max[target] == Max[walker] + 1)
    {
        fa[cur] = target;
        ++du[target];
        return;
    }

    // Otherwise split `target`: the clone takes over the shorter strings.
    const int clone = ++tot;
    Max[clone] = Max[walker] + 1;
    fa[clone] = fa[target];         // clone replaces target under its old parent
    fa[target] = clone;
    fa[cur] = clone;
    du[clone] = 2;                  // children: target and cur

    // Redirect the transitions that used to reach target through short strings.
    for (; walker && ch[walker][x] == target; walker = fa[walker])
        ch[walker][x] = clone;

    // The clone inherits every outgoing transition of target.
    for (int c = 1; c <= 26; ++c)
        ch[clone][c] = ch[target][c];
}
// Sums cnt[] up the fa tree so that every state ends up with its own count
// plus the counts of all its descendants.
//
// States are processed leaves-first: a state enters the queue once du[] (its
// number of unprocessed fa-children) reaches zero. du[] is consumed by this
// pass. A root has fa == 0, so index 0 acts as a sink: cnt[0] collects the
// roots' totals and du[0] only ever goes negative, hence 0 is never enqueued.
inline void top()
{
    int head = 0,tail = 0;

    // Seed the queue with every state that has no children.
    for (int i = 1; i <= tot; i++)
        if (!du[i]) Q[++tail] = i;

    while (head < tail)
    {
        const int u = Q[++head];
        const int parent = fa[u];
        cnt[parent] += cnt[u];
        // The parent becomes ready once its last child has been folded in.
        if (--du[parent] == 0)
            Q[++tail] = parent;
    }
}
inline void dfs(int u,int v)
{
ans += cnt[u] * cnt[v];
for (int i = 1; i <= 26; i++)
{
if (!ch[u][i] || !ch[v][i]) continue;
dfs(ch[u][i],ch[v][i]);
}
}
// Reads two strings from stdin, builds one suffix automaton per string in the
// shared state arrays, and prints the value accumulated in ans by dfs().
int main()
{
    for (int d = 1; d <= 2; d++)
    {
        // Fresh root; resetting `last` makes the next insert start a new automaton.
        rt[d] = last = ++tot;

        // A missing line is treated as an empty string instead of silently
        // re-using whatever the buffer still holds from the previous read.
        if (scanf("%s",s + 1) != 1) s[1] = '\0';
        n = strlen(s + 1);

        for (int i = 1; i <= n; i++)
            insert(s[i] - 'a' + 1,d);   // map 'a'..'z' to labels 1..26
    }

    top();

    // The roots (and the index-0 sink that top() accumulates into) must not
    // contribute to the products summed by dfs.
    cnt[0] = cnt[rt[1]] = cnt[rt[2]] = 0;

    dfs(rt[1],rt[2]);
    printf("%lld",ans);
    return 0;
}