4566: [Haoi2016]找相同字符
Time Limit: 20 Sec Memory Limit: 256 MBSubmit: 536 Solved: 298
[ Submit][ Status][ Discuss]
Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
bbaa
Sample Output
10
题解:
两个串的重复子串的个数=合并后的后缀的前缀和(以两个串分别开头)
而合并后的后缀的前缀和(以两个串分别开头)=合并后的后缀的前缀和-两个串单独的前缀和
具体做法详见差异
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int N=400010;
int ww[N],wx[N],wy[N],wv[N];
bool cmp(int *r,int x,int x1,int ln)
{
return r[x]==r[x1]&&r[x+ln]==r[x1+ln];
}
void da(int *r,int *sa,int n,int m)
{
int *x=wx,*y=wy,*t,i,j,p;
for(i=0;i<m;i++) ww[i]=0;
for(i=0;i<n;i++) ww[x[i]=r[i]]++;
for(i=1;i<m;i++) ww[i]+=ww[i-1];
for(i=n-1;i>=0;i--) sa[--ww[x[i]]]=i;
for(j=1,p=1;p<n;j*=2,m=p)
{
for(p=0,i=n-j;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(i=0;i<n;i++) wv[i]=x[y[i]];
for(i=0;i<m;i++) ww[i]=0;
for(i=0;i<n;i++) ww[wv[i]]++;
for(i=1;i<m;i++) ww[i]+=ww[i-1];
for(i=n-1;i>=0;i--) sa[--ww[wv[i]]]=y[i];
for(t=x,x=y,y=t,i=1,x[sa[0]]=0,p=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i],sa[i-1],j)?p-1:p++;
}
}
int rank[N],h[N];
void calheight(int *r,int *sa,int n)
{
int i,j,k=0;
for(i=1;i<=n;i++) rank[sa[i]]=i;
for(i=0;i<n;h[rank[i++]]=k)
for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
return;
}
int q[N],t,ls[N],rs[N];
long long ans,sum;
char s1[N],s2[N];
int l1,l2,len=0;
int r[N],sa[N];
void get_sum()
{
t=0,sum=0;h[0]=h[len+1]=-1;
for(int i=1;i<=len+1;i++)
{
while(t&&h[q[t]]>=h[i]) rs[q[t--]]=i;
q[++t]=i;
}
t=0;
for(int i=len;i>=0;i--)
{
while(t&&h[q[t]]>h[i]) ls[q[t--]]=i;
q[++t]=i;
}
for(int i=1;i<=len;i++)
{
sum+=(long long)(i-ls[i])*(rs[i]-i)*h[i];
// printf("%d",rs[i]-i);
//printf("%d %d %d %d\n",h[i],i,ls[i],rs[i]);
}
}
int main()
{
scanf("%s%s",s1,s2);
l1=strlen(s1),l2=strlen(s2);
for(int i=0;i<l1;i++)
r[len++]=s1[i]-'a'+3;
r[len++]=1;
da(r,sa,len+1,259);
calheight(r,sa,len);//printf("!");
get_sum();//
//printf("\n");
ans-=sum;
for(int i=0;i<l2;i++)
r[len++]=s2[i]-'a'+3;
r[len++]=2;
da(r,sa,len+1,259);
calheight(r,sa,len);
get_sum();//printf("%d\n",sum);
ans+=sum;
len=0;
for(int i=0;i<l2;i++)
r[len++]=s2[i]-'a'+3;
da(r,sa,len+1,259);
calheight(r,sa,len);
get_sum();
ans-=sum;
printf("%lld\n",ans);
}