题目:http://codeforces.com/contest/1313/problem/E
参考:http://codeforces.com/blog/entry/74146
题意:给定字符串a,b,s。求子串组合数使得
a
[
l
1
,
r
1
]
+
b
[
l
2
,
r
2
]
=
=
s
a[l1,r1]+b[l2,r2]==s
a[l1,r1]+b[l2,r2]==s,要求
[
l
1
,
r
1
]
,
[
l
2
,
r
2
]
[l1,r1],[l2,r2]
[l1,r1],[l2,r2]交集非空。
题解:用z算法求出a在s中的最长前缀
l
c
p
lcp
lcp,b在s中的最长后缀
l
c
s
lcs
lcs。从左到右枚举a字符串,取定
l
1
l_1
l1,那么相应的
r
2
r_2
r2取值为
l
1
<
=
r
2
<
=
l
1
+
m
−
2
l_1<=r_2<=l_1+m-2
l1<=r2<=l1+m−2,我们把满足当前区间的所有
r
2
r_2
r2对应的最左端点
r
2
−
l
c
s
r
2
r_2-lcs_{r_2}
r2−lcsr2都扔进树状数组上,分别统计数量与总和
c
n
t
,
s
u
m
cnt,sum
cnt,sum。
那么当前
l
1
l_1
l1为左端点的情况数有
l
c
s
∗
c
n
t
−
s
u
m
lcs*cnt-sum
lcs∗cnt−sum,
s
u
m
sum
sum部分是取不到的情况,需要减去。这只是大概抽象的说明,代码实现和理论说明有出入,详见代码。
#include<bits/stdc++.h>
using namespace std;
const int maxn = 500010;
#define ll long long
int n,m;
char a[maxn],b[maxn];
char s[maxn*2],c[maxn*3];
int z[maxn*3];
int lcp[maxn];//lcp[i]表示a[i]开始,能匹配s的最长前缀
int lcs[maxn];//lcs[i]表示b[i]开始,能匹配s的最长后缀
//z算法,求解给定字符串每个位置能匹配自身的最长前缀
void z_init(int len) {
int l = 1,r = 1;
z[1] = len;
for(int i = 2;i <= len;i++) {
if(i > r) {
l = i;r = i;
while(r<=len && c[r-i+1]==c[r]) r++;
z[i] = r-l;r--;
}else {
int k = i-l+1;
if(z[k]<r-i+1) z[i] = z[k];
else {
l = i;
while(r<=len && c[r-i+1]==c[r]) r++;
z[i] = r-l;r--;
}
}
}
}
void init() {
//求lcp
for(int i = 1;i <= m;i++) c[i] = s[i];
c[m+1] = '#';
for(int i = 1;i <= n;i++) c[m+1+i] = a[i];
c[n+m+2] = '\0';
z_init(n+m+1);
for(int i = 1;i <= n;i++) lcp[i] = z[m+1+i];
//求lcs
for(int i = 1;i <= m;i++) c[i] = s[m+1-i];
c[m+1] = '#';
for(int i = 1;i <= n;i++) c[m+1+i] = b[n+1-i];
c[n+m+2] = '\0';
z_init(n+m+1);
for(int i = 1;i <= n;i++) lcs[i] = z[m+1+n+1-i];
}
//Fenwick Tree
ll cnt[maxn*2],sum[maxn*2];
int lowbit(int x) {
return x&(-x);
}
void add(int v) {
int x = v;
while(x <= n) sum[x]+=v,cnt[x]++,x+=lowbit(x);
}
void sub(int v) {
int x = v;
while(x <= n) sum[x]-=v,cnt[x]--,x+=lowbit(x);
}
ll get_sum(int x) {
ll res = 0;
while(x) res+=sum[x],x-=lowbit(x);
return res;
}
ll get_cnt(int x) {
ll res = 0;
while(x) res+=cnt[x],x-=lowbit(x);
return res;
}
int main() {
scanf("%d%d",&n,&m);
scanf("%s%s%s",a+1,b+1,s+1);
init();//puts("VE");
ll ans = 0;
/*
*l1 <= r2 <= l1+m-2 as |r2-l1| <= m-1
*initial l1 = 1
*/
for(int i = 1;i <= min(n,m-1);i++) add(max(1,m-lcs[i]));//puts("S");
for(int i = 1,r;i <= n;i++) {
r = min(m-1,lcp[i]);
ans += 1LL*(r+1)*get_cnt(r)-get_sum(r);
sub(max(1,m-lcs[i]));//delete case of "i as r2"
if(i+m-1 <= n) add(max(1,m-lcs[i+m-1]));//add case of "i+m-1 as r2"
}
printf("%I64d\n",ans);
}