题目大意:
给你一个S串和m个T串,
求S串中有多少子串不是另外m个T串的子串。
解题思路:
对于所有的串,首先将他们连接。接下来我们就要统计对于0到 strlen(s) 这些位置,每个位置有多少与它重复的子串,统计出来之后拿串的总个数减去重复的即可。
统计重复的就从前往后扫一遍 对于在S串和在T串的情况考虑。同时也要从后往前扫一遍。
最后在把S串中自己重复的子串去掉即可。
但是这题让我明白了好多小细节,之前我对于小写字母都是直接-'a',但是这个题目就会产生问题,其次每个串之间的分隔符,之前用的一直都是一个没有出现过的符号插中间,这个题目也会出现问题。可能是我对后缀数组理解的还不够深刻。。。
Ac代码:
#include<bits/stdc++.h>
#define rank ra
using namespace std;
const int maxn=3e5+10;
const int INF=1e9+7;
typedef long long ll;
char s[maxn];
int n,sa[maxn],rank[maxn],height[maxn],pos[maxn];
int t1[maxn],t2[maxn],r[maxn],c[maxn];
bool cmp(int *r,int a,int b,int l)
{
return r[a]==r[b] && r[a+l]==r[b+l];
}
void da(int str[],int sa[],int rank[],int height[],int n,int m)
{
n++;
int i,j,p,*x=t1,*y=t2;
for(int i=0;i<m;i++) c[i]=0;
for(int i=0;i<n;i++) c[x[i]=str[i]]++;
for(int i=1;i<m;i++) c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
for(int j=1;j<=n;j<<=1)
{
p=0;
for(int i=n-j;i<n;i++) y[p++]=i;
for(int i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(int i=0;i<m;i++) c[i]=0;
for(int i=0;i<n;i++) c[x[y[i]]]++;
for(int i=1;i<m;i++) c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1,x[sa[0]]=0;
for(int i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
if(p>=n) break;
m=p;
}
int k=0;
n--;
for(int i=0;i<=n;i++) rank[sa[i]]=i;
for(int i=0;i<n;i++)
{
if(k) k--;
j=sa[rank[i]-1];
while(str[i+k]==str[j+k]) k++;
height[rank[i]]=k;
}
}
int main()
{
int QAQ,kase=0;
scanf("%d",&QAQ);
while(QAQ--)
{
int m; scanf("%d",&m);
scanf(" %s",s);
int ls=strlen(s),len=0,num=30;
for(int i=0;i<ls;i++) r[len++]=s[i]-'a'+1; //这里注意+1
while(m--)
{
r[len++]=num++; //注意分隔符为num++
scanf(" %s",s);
int lk=strlen(s);
for(int i=0;i<lk;i++) r[len++]=s[i]-'a'+1;
}
r[len]=0;n=len;
da(r,sa,rank,height,n,num+1);
memset(pos,0,sizeof pos);
int tmp=INF;
for(int i=1;i<=n;i++) //从前往后统计重复子串
{
if(sa[i]<ls)
{
tmp=min(tmp,height[i]);
pos[sa[i]]=max(pos[sa[i]],tmp);
}
else tmp=INF;
}
tmp=INF;
for(int i=n;i>=1;i--) //从后往前统计重复子串
{
if(sa[i-1]<ls)
{
tmp=min(tmp,height[i]);
pos[sa[i-1]]=max(pos[sa[i-1]],tmp);
}
else tmp=INF;
}
for(int i=1;i<=n;i++) //统计自己与自己重复的子串
{
if(sa[i]<ls&&sa[i-1]<ls)
{
pos[sa[i-1]]=max(pos[sa[i-1]],height[i]);
}
}
ll res=1LL*ls*(ls+1)/2;
for(int i=0;i<ls;i++)
res-=pos[i];
printf("Case %d: %lld\n",++kase,res);
}
return 0;
}