题目大意:
给出三个总长度不超过 3*10^5 的字符串 s1, s2, s3, 每个字符串只包含小写字母 'a' ~ 'z'。对于每一个 L (1 <= L <= minLength(s1, s2, s3), 也就是 L 不超过 s1, s2, s3 中最短的长度), 求出存在多少个三元组 (i, j, k) 使得 s1[i ~ i + L - 1] == s2[j ~ j + L - 1] == s3[k ~ k + L - 1], 结果对 10^9 + 7 取模之后输出。
大致思路:
首先不难想到用后缀数组处理三个串拼接起来的总串, 并记录每一个字符的来源, 也就是记录每个后缀来自哪个原串。然后按照 height 数组从大到小的顺序, 利用并查集合并排名相邻的后缀所在的区间并进行计算。注意两个区间合并之后, 只有 (i, j, k) 三者不全来自同一个原来的区间的三元组才是新增的贡献, 所以稍微容斥一下即可: 先减去两个旧区间各自的贡献, 再加上合并后整个区间的贡献。
由于事先对height数组排序了, 所以也不需要树状数组之类的来辅助更新答案, 直接利用排序好的height数组的单调性即可
之前想过一个从height由小到大切割区间进行分治dfs的方法, 然后利用树状数组辅助更新答案, 但是复杂度还是太高了...果然还是需要用并查集
代码如下:
Result : Accepted Memory : 27500 KB Time : 202 ms
/*
* Author: Gatevin
* Created Time: 2015/3/18 16:10:35
* File Name: Kotori_Itsuka.cpp
*/
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
// Floating-point tolerance; not referenced by the code visible in this file.
const double eps = 1e-8;
// 64-bit signed integer used for counters and modular arithmetic.
typedef long long lint;
// Capacity of every fixed-size array: the 3*10^5 total input length from the
// problem statement plus a little slack for the separators.
#define maxn 300010
// Scratch buffers used by da().
int wa[maxn];  // first rank/key buffer
int wb[maxn];  // second rank/key buffer
int wv[maxn];  // keys of the suffixes in second-key order
int Ws[maxn];  // counting-sort buckets
// Equality probe used by da(): returns 1 when positions a and b of r hold the
// same value both at offset 0 and at offset l, and 0 otherwise.
// The offset-l pair is only read when the offset-0 pair already matches.
int cmp(int *r, int a, int b, int l)
{
    if (r[a] != r[b])
        return 0;
    return r[a + l] == r[b + l] ? 1 : 0;
}
void da(int *r, int *sa, int n, int m)
{
int *x = wa, *y = wb, *t, i, j, p;
for(i = 0; i < m; i++) Ws[i] = 0;
for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;
for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;
for(j = 1, p = 1; p < n; j *= 2, m = p)
{
for(p = 0, i = n - j; i < n; i++) y[p++] = i;
for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
for(i = 0; i < n; i++) wv[i] = x[y[i]];
for(i = 0; i < m; i++) Ws[i] = 0;
for(i = 0; i < n; i++) Ws[wv[i]]++;
for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];
for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
}
return;
}
int rank[maxn], height[maxn];
void calheight(int *r, int *sa, int n)
{
int i, j, k = 0;
for(i = 1; i <= n; i++) rank[sa[i]] = i;
for(i = 0; i < n; height[rank[i++]] = k)
for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);
return;
}
// f[x] : parent link of x in the disjoint-set forest; x is a root when f[x] == x.
int f[maxn];

// Returns the root of the set containing x and re-points every node on the
// walked path directly at that root (full path compression).
//
// Iterative on purpose: main() links roots without any size/rank balancing, so
// a chain can grow to roughly the whole input length (an input such as
// "aaa...ab" does this) and a recursive walk over it can overflow the stack.
int find(int x)
{
    // Pass 1: locate the root.
    int root = x;
    while (root != f[root])
        root = f[root];
    // Pass 2: compress the path from x up to the root.
    while (x != root)
    {
        int parent = f[x];
        f[x] = root;
        x = parent;
    }
    return root;
}
bool cmp2(int a, int b)
{
return height[a] > height[b];
}
// in     : buffer holding the string most recently read by scanf.
// s      : the three inputs concatenated as integers; letters map to 1..26, the
//          first two separators are 27 and 28, the last one is overwritten by 0.
// sa     : suffix array of s, filled by da().
// p      : suffix-array positions, entries 1..N later sorted by descending height.
// belong : belong[i] = index (0..2) of the string s[i] came from, -1 for a separator.
// N      : after setup, the index of the terminating 0 in s.
// cnt    : cnt[x][k] = number of suffixes of string k in the set whose root is x.
// ans    : ans[L] = number of matching triples for length L, modulo mod.
char in[maxn];
int s[maxn], sa[maxn], p[maxn], belong[maxn], N;
lint cnt[maxn][3], ans[maxn];
const lint mod = 1e9 + 7;

// Reads three lowercase strings and prints, for every L from 1 to the length
// of the shortest one, the number of triples (i, j, k) whose length-L
// substrings of the three strings are equal, modulo 1e9 + 7. The values are
// written on one line, each followed by a space.
int main()
{
    int mlen = 1e9;  // length of the shortest input string
    N = 0;
    for(int i = 0; i < 3; i++)
    {
        scanf("%s", in);
        int len = strlen(in);
        mlen = min(len, mlen);
        for(int j = 0; j < len; j++)
        {
            belong[N] = i;
            s[N++] = in[j] - 'a' + 1;
        }
        // Each separator value occurs only once in s, so no common prefix of
        // two suffixes can contain one.
        belong[N] = -1;
        s[N++] = 27 + i;
    }
    // Replace the last separator by the unique smallest symbol 0 that da()
    // and calheight() are called with as the terminator.
    N--;
    s[N] = 0;
    da(s, sa, N + 1, 30);
    calheight(s, sa, N);

    // Every suffix starts as its own set, counted under its source string.
    for(int i = 0; i <= N; i++) p[i] = f[i] = i;
    for(int i = 0; i <= N; i++)
        if(belong[i] != -1) cnt[i][belong[i]]++;

    // Join adjacent suffixes from the longest common prefix to the shortest.
    sort(p + 1, p + N + 1, cmp2);
    lint result = 0;  // sum over all sets of cnt0 * cnt1 * cnt2, modulo mod
    for(int i = 1; i <= N; i++)
    {
        // The height just dropped: every pair with a larger height is already
        // merged, so `result` is final for the lengths in between.
        if(i > 1 && height[p[i]] != height[p[i - 1]])
            for(int j = height[p[i]] + 1; j <= height[p[i - 1]]; j++)
                ans[j] = result;
        int bl = find(sa[p[i]]), br = find(sa[p[i] - 1]);
        // Inclusion-exclusion: drop both old contributions, merge the
        // counters, then add the contribution of the merged set.
        result = (result - cnt[bl][0]*cnt[bl][1]*cnt[bl][2] % mod + mod) % mod;
        result = (result - cnt[br][0]*cnt[br][1]*cnt[br][2] % mod + mod) % mod;
        for(int j = 0; j < 3; j++)
            cnt[bl][j] = (cnt[bl][j] + cnt[br][j]) % mod;
        f[br] = bl;
        result = (result + cnt[bl][0]*cnt[bl][1]*cnt[bl][2]) % mod;
    }
    // %lld is the standard conversion for long long; the MSVC-only %I64d is
    // parsed as flag 'I' + width 64 + %d by glibc and prints garbage there.
    for(int i = 1; i <= mlen; i++)
        printf("%lld ", ans[i]);
    return 0;
}
顺带祭奠一下以前写的TLE了的方法...
Result : Time Limit Exceeded on test 42
/*
* Author: Gatevin
* Created Time: 2015/3/12 22:32:12
* File Name: Kotori_Itsuka.cpp
*/
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
// Floating-point tolerance; not referenced by the code visible in this file.
const double eps = 1e-8;
// 64-bit signed integer used for counters and modular arithmetic.
typedef long long lint;
// Modulus applied to every reported count.
const lint mod = 1000000007LL;
// Capacity of every fixed-size array: the 3*10^5 total input length from the
// problem statement plus some slack for the separators.
#define maxn 300100
// Scratch buffers used by da().
int wa[maxn];  // first rank/key buffer
int wb[maxn];  // second rank/key buffer
int wv[maxn];  // keys of the suffixes in second-key order
int Ws[maxn];  // counting-sort buckets
// Equality probe used by da(): yields 1 if r agrees at positions a and b and
// also at a + l and b + l, otherwise 0. The second pair is read only when the
// first pair is equal.
int cmp(int *r, int a, int b, int l)
{
    const bool firstHalfEqual = (r[a] == r[b]);
    if (!firstHalfEqual)
        return 0;
    const bool secondHalfEqual = (r[a + l] == r[b + l]);
    return secondHalfEqual ? 1 : 0;
}
void da(int *r, int *sa, int n, int m)
{
int *x = wa, *y = wb, *t, i, j, p;
for(i = 0; i < m; i++) Ws[i] = 0;
for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;
for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;
for(j = 1, p = 1; p < n; j *= 2, m = p)
{
for(p = 0, i = n - j; i < n; i++) y[p++] = i;
for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
for(i = 0; i < n; i++) wv[i] = x[y[i]];
for(i = 0; i < m; i++) Ws[i] = 0;
for(i = 0; i < n; i++) Ws[wv[i]]++;
for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];
for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];
for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
}
return;
}
int rank[maxn], height[maxn];
void calheight(int *r, int *sa, int n)
{
int i, j, k = 0;
for(i = 1; i <= n; i++) rank[sa[i]] = i;
for(i = 0; i < n; height[rank[i++]] = k)
for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);
return;
}
// Buffer holding the string most recently read by scanf.
char in[maxn];
// The three inputs concatenated as integers (letters 1..26 plus separators).
int s[maxn];
// Suffix array of s, filled by da().
int sa[maxn];
// belong[i]: index (1..3) of the string s[i] came from, -1 for a separator.
int belong[maxn];
// After setup in main(): index of the terminating 0 in s; also the upper
// bound of the Fenwick-tree walk in add().
int N;
// Per-length answers; only mentioned in commented-out code in this file.
lint ans[maxn];
// Fenwick-tree storage used by add() and query().
lint C[maxn];
// Isolates the lowest set bit of x (for example 12 -> 4); yields 0 for x == 0.
int lowbit(int x)
{
    const int negated = -x;
    return x & negated;
}
// Fenwick-tree point update: adds value (modulo mod) to C[L] and to every
// node reached by repeatedly adding lowbit, for as long as the index is <= N.
// L must be at least 1: lowbit(0) is 0, so index 0 would never advance.
void add(int L, lint value)
{
    for (int node = L; node <= N; node += lowbit(node))
        C[node] = (C[node] + value) % mod;
}
void update(int L, int R, lint value)//区间更新[L, R] += value
{
add(L, value), add(R + 1, (-value + mod) % mod);
}
// Point query: returns the accumulated value at position pos, modulo mod, by
// summing the Fenwick-tree nodes reached while stripping the lowest set bit.
lint query(int pos)
{
    lint total = 0;
    for (int node = pos; node != 0; node -= lowbit(node))
        total = (total + C[node]) % mod;
    return total;
}
// Recursively splits the suffix-array range [L, R] into maximal runs whose
// height exceeds h and credits each run's triple count to the lengths it covers.
//   L, R : inclusive range of suffix-array positions to examine.
//   h    : height level already handled by the caller; positions whose height
//          equals h are run boundaries and are skipped.
// NOTE(review): assumes every height in [L, R] is >= h, which the call pattern
// in this file provides; a smaller height would keep the loop from advancing.
void dfs(int L, int R, int h)
{
    for (int start = L; start <= R; )
    {
        if (height[start] == h)
        {
            ++start;
            continue;
        }

        // cnt[k] = number of suffixes from string k (1..3) inside the run.
        lint cnt[4] = {0, 0, 0, 0};
        // The suffix just before the run shares the same prefix, so count it too.
        ++cnt[belong[sa[start - 1]]];

        int lowest = height[start];  // smallest height seen inside the run
        int stop = start;            // first position after the run
        for (; stop <= R && height[stop] > h; ++stop)
        {
            ++cnt[belong[sa[stop]]];
            lowest = min(lowest, height[stop]);
        }

        // Same effect as adding the product to ans[k] for every k in
        // h + 1 .. lowest, done as a single range update.
        update(h + 1, lowest, cnt[1] * cnt[2] * cnt[3] % mod);
        dfs(start, stop - 1, lowest);
        start = stop;
    }
}
// Runs the recursive split over all suffix-array positions 1..N, then prints
// the accumulated count for every length 1..mlen on one line, each value
// followed by a space.
//   mlen : largest length to report (the shortest input string's length).
void solve(int mlen)
{
    dfs(1, N, 0);
    // %lld is the standard conversion for long long; the MSVC-only %I64d is
    // parsed as flag 'I' + width 64 + %d by glibc and prints garbage there.
    for(int i = 1; i <= mlen; i++)
        printf("%lld ", query(i));
}
int main()
{
int minlen = 1e9;
N = 0;
for(int i = 1; i <= 3; i++)
{
scanf("%s", in);
int len = strlen(in);
minlen = min(minlen, len);
for(int j = 0; j < len; j++)
{
belong[N] = i;
s[N++] = in[j] - 'a' + 1;
}
belong[N] = -1;
s[N++] = 26 + i;
}
N--;
s[N] = 0;
da(s, sa, N + 1, 30);
calheight(s, sa, N);
solve(minlen);
return 0;
}