题意:给出两个字符串A,B和一个数字K。
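计算集合 S = {(i,j,k) | k >= K, A(i,k) == B(j,k)} 的元素个数,其中 A(i,k) 表示 A 中从位置 i 开始、长度为 k 的子串(B(j,k) 同理)。
思路: 先把 A、B 用一个分隔符拼接起来,用后缀数组处理出 high[](排名相邻的两个后缀的最长公共前缀长度);然后对 high[] 建立线段树,线段树的每个节点记录两个信息:对应区间的最小值和该最小值的位置(下标);最后借助线段树的区间最小值查询,对 high[] 的区间进行 dfs。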
dfs的过程都写在注释里了。
void dfs(int l,int r,LL &anw,int len,int Low)
{
if(l > r)
return ;
N tmp = Query(1,1,len,l,r);//询问此段区间内的最小值及其位置。
if(tmp.Min >= Low)
{
//此段区间内分属A,B串的个数相乘然后与可取长度的个数相乘。
anw += (ans1[r]-ans1[l-2])*(ans2[r]-ans2[l-2])*(tmp.Min-Low+1);
dfs(l,tmp.site-1,anw,len,tmp.Min+1);//从最小值处分开继续dfs,[l,r]内的tmp.Min均以计算,故加一。
dfs(tmp.site+1,r,anw,len,tmp.Min+1);
return ;
}
dfs(l,tmp.site-1,anw,len,Low);
dfs(tmp.site+1,r,anw,len,Low);
}
下面为全部代码
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <queue>
#include <cmath>
#include <stack>
#include <map>
#include <ctime>
#include <iomanip>
// Passes "/STACK:1024000000" to the linker to request a very large stack;
// the recursive dfs below can nest deeply.
// NOTE(review): this option syntax looks MSVC-specific; other toolchains
// presumably ignore the pragma -- confirm for the target compiler.
#pragma comment(linker, "/STACK:1024000000");
// Floating-point tolerance (not referenced in the code visible here).
#define EPS (1e-6)
// Shorthand for the 64-bit integer type used for the answer and prefix sums.
#define LL long long
// Unsigned 64-bit shorthand (not referenced in the code visible here).
#define ULL unsigned long long
// Alias for the compiler-specific __int64 type (not referenced in the code visible here).
#define _LL __int64
// Large sentinel value (not referenced in the code visible here).
#define INF 0x3f3f3f3f
// Modulus constant (not referenced in the code visible here).
#define Mod 1000000007
using namespace std;
// Capacity of the text buffer and of the per-rank arrays.
const int MAXN = 200510;

// Text buffer; main() stores the input starting at s[1].
char s[MAXN];

// Suffix-array working arrays filled by Get_SA. Rank, sa and tr are twice as
// long because Get_SA reads Rank[sa[i]+k], which can index past the text.
int Rank[2*MAXN];
int sa[2*MAXN];
int tr[2*MAXN];
// high[r]: common-prefix length of the suffixes ranked r-1 and r.
int high[MAXN];

// Node of a singly linked bucket list. Get_SA uses the leading entries of
// edge[] as bucket heads and the entries from Top onwards as list nodes.
struct EDGE
{
    int v;
    int next;
};
EDGE edge[2*MAXN];

// tail[u]: index of the last node of bucket u (the head itself when empty).
int tail[MAXN];
// Index of the next unused entry of edge[].
int Top;

// Appends value v to the end of bucket u, taking a fresh node from edge[].
inline void Link(int u,int v)
{
    const int node = Top++;
    edge[node].v = v;
    edge[node].next = -1;
    edge[tail[u]].next = node;
    tail[u] = node;
}
// Builds the suffix array of s[1..n] by prefix doubling with bucket (radix)
// passes, then derives the height array.
//   s : text, 1-based
//   n : text length
//   m : largest bucket index needed for the first pass, where each character
//       c is mapped to bucket c-'A'
// Results are left in the globals sa[1..n] (suffix start positions in sorted
// order), Rank[1..n] (rank of the suffix starting at each position) and
// high[1..n].
void Get_SA(char *s,int n,int m)
{
// Zero both arrays; Rank entries past n stay 0, so a second key that reaches
// beyond the end of the text sorts before every real rank.
memset(Rank,0,sizeof(Rank));
memset(sa,0,sizeof(sa));
int i,j,k,ans,site;
// Empty every bucket: edge[i] is the head of bucket i, tail[i] its last node.
for(i = max(n,m);i >= 0; --i)
tail[i] = i,edge[i].next = -1;
// List nodes are allocated after the bucket heads.
Top = max(n,m)+1;
// First pass: bucket every position by its first character.
for(i = 1; i <= n; ++i)
Link(s[i]-'A',i);
ans = 1,site = 1;
// Walk the buckets in increasing order to obtain sa sorted by first
// character; positions in the same bucket share a rank.
for(i = 0; i <= m; ++i)
{
for(j = edge[i].next; j != -1; j = edge[j].next)
sa[site++] = edge[j].v,Rank[edge[j].v] = ans;
// Only non-empty buckets consume a rank value.
if(edge[i].next != -1)
ans++;
// Reset the bucket for the next pass.
tail[i] = i,edge[i].next = -1;
}
// Doubling: at the start of each round Rank orders suffixes by their first k
// characters; the round extends that to 2k.
for(k = 1;k <= n; k <<= 1)
{
// Pass 1: distribute by the rank of the second half (0 when past the end).
// Link appends at the tail, so equal keys keep their current order.
Top = n+1;
for(i = 1;i <= n; ++i)
Link(Rank[sa[i]+k],sa[i]);
site = 1;
for(i = 0;i <= n; ++i)
{
for(j = edge[i].next;j != -1; j = edge[j].next)
sa[site++] = edge[j].v;
tail[i] = i,edge[i].next = -1;
}
// Pass 2: distribute by the rank of the first half; the order produced by
// pass 1 breaks ties.
Top = n+1;
for(i = 1;i <= n; ++i)
Link(Rank[sa[i]],sa[i]);
site = 1;
for(i = 1;i <= n; ++i)
{
for(j = edge[i].next;j != -1; j = edge[j].next)
sa[site++] = edge[j].v;
tail[i] = i,edge[i].next = -1;
}
// Recompute ranks into tr: the rank advances whenever either half differs
// from the previous suffix in sorted order.
for(tr[sa[1]] = 1,i = 2,ans = 1;i <= n; ++i)
{
if(Rank[sa[i]] != Rank[sa[i-1]] || Rank[sa[i]+k] != Rank[sa[i-1]+k])
ans++;
tr[sa[i]] = ans;
}
for(i = 1;i <= n; ++i)
Rank[i] = tr[i];
// All ranks are distinct: the order is final.
if(ans >= n)
break;
}
// Height array: high[r] is the common-prefix length of the suffixes ranked
// r-1 and r. Positions are visited in text order and k is carried over
// (minus one) from the previous position instead of restarting at 0.
for(i = 1,k = 1;i <= n; ++i)
{
if(k) k--;
// The smallest suffix has no predecessor; high[1] is set to its own length.
if(Rank[i] == 1) {k = 0;high[1] = n-sa[1]+1;continue;}
j = sa[Rank[i]-1];
while(i+k <= n && j+k <= n && s[i+k] == s[j+k])
k++;
high[Rank[i]] = k;
}
// Debug dump of the computed arrays (disabled):
// for(i = 1;i <= n; ++i)
// printf("i = %2d SA = %2d Rank = %2d high = %2d\n",i,sa[i],Rank[i],high[i]);
// Everything above constructs Rank, sa and high.
}
// Prefix counts over suffix-array order, filled by main():
// ans1[i] = how many of sa[1..i] start inside the first string,
// ans2[i] = how many of sa[1..i] start inside the second string.
LL ans1[MAXN];
LL ans2[MAXN];

// Segment-tree node: the minimum of high[] over the node's range (Min) and
// the index at which that minimum occurs (site).
struct N
{
    int site;
    int Min;
};
// Segment-tree storage, rooted at index 1.
N st[4*MAXN];
void Init(int site,int l,int r)
{
if(l == r)
{
st[site].Min = high[l],st[site].site = l;
return ;
}
int mid = (l+r)>>1;
Init(site<<1,l,mid);
Init(site<<1|1,mid+1,r);
st[site] = st[site<<1].Min < st[site<<1|1].Min ? st[site<<1] : st[site<<1|1];
}
// Returns the minimum of high[l..r] and its index.
//   site : current tree node, which covers [L, R]
//   l, r : inclusive query range, expected to lie within [L, R]
// When the minimum occurs in both halves of a split query, the entry from
// the right half is returned.
N Query(int site,int L,int R,int l,int r)
{
    // The node covers exactly the requested range.
    if(L == l && R == r)
        return st[site];

    const int mid = (L+R)>>1;
    const int lc = site<<1;
    const int rc = lc|1;

    // Query lies entirely in one child.
    if(r <= mid)
        return Query(lc,L,mid,l,r);
    if(l > mid)
        return Query(rc,mid+1,R,l,r);

    // Query straddles the midpoint: combine both halves.
    const N leftBest = Query(lc,L,mid,l,mid);
    const N rightBest = Query(rc,mid+1,R,mid+1,r);
    return leftBest.Min < rightBest.Min ? leftBest : rightBest;
}
//l,r为左右区间端点。anw为最终答案。len为high[]的size,Low为此次dfs中符合要求的最小值,初始时为输入的K。
void dfs(int l,int r,LL &anw,int len,int Low)
{
if(l > r)
return ;
N tmp = Query(1,1,len,l,r);//询问此段区间内的最小值及其位置。
if(tmp.Min >= Low)
{
//此段区间内分属A,B串的个数相乘然后与可取长度的个数相乘。
anw += (ans1[r]-ans1[l-2])*(ans2[r]-ans2[l-2])*(tmp.Min-Low+1);
dfs(l,tmp.site-1,anw,len,tmp.Min+1);//从最小值处分开继续dfs,[l,r]内的tmp.Min均以计算,故加一。
dfs(tmp.site+1,r,anw,len,tmp.Min+1);
return ;
}
dfs(l,tmp.site-1,anw,len,Low);
dfs(tmp.site+1,r,anw,len,Low);
}
// Reads test cases of the form "K, string A, string B" until K is 0 or the
// input ends, and prints one answer per case: the number of triples (i,j,k)
// with k >= K for which the length-k substrings A(i,k) and B(j,k) are equal.
// Returns 0.
int main()
{
    int n,k,len,i;
    // Require a successful read: scanf returns EOF (non-zero) at end of
    // input, which would otherwise keep the loop running on stale data.
    while(scanf("%d",&k) == 1 && k)
    {
        // First string goes to s[1..n].
        if(scanf("%s",s+1) != 1)
            break;
        n = static_cast<int>(strlen(s+1));
        // Second string starts at s[n+2], leaving s[n+1] for the separator.
        if(scanf("%s",s+n+2) != 1)
            break;
        // Separator: one past 'z', so it compares greater than every letter
        // (assumes both inputs consist of Latin letters only).
        s[n+1] = 'z'+1;
        // Length of the joined text A + separator + B.
        len = static_cast<int>(strlen(s+1));
        Get_SA(s,len,200);

        // Prefix counts over suffix-array order: ans1 counts suffixes that
        // start in A, ans2 those that start in B; the separator's own suffix
        // is counted in neither.
        ans1[0] = 0;
        ans2[0] = 0;
        for(i = 1;i <= len; ++i)
        {
            ans1[i] = ans1[i-1];
            ans2[i] = ans2[i-1];
            if(sa[i] <= n)
                ans1[i]++;
            if(sa[i] > n+1)
                ans2[i]++;
        }

        Init(1,1,len);
        LL anw = 0;
        // high[1] has no predecessor suffix, so the useful range starts at 2.
        dfs(2,len,anw,len,k);
        // Stream output instead of the Microsoft-only "%I64d" conversion so
        // the 64-bit answer prints correctly on every toolchain.
        std::cout << anw << '\n';
    }
    return 0;
}