题目链接:点这里!!!
题意:
给你一个字符x和一个字符串s,问你s中有多少个不相同的子串?且必须含有字符x。
题解:
1、我们可以利用后缀数组来做这道题。我们求出我们要的三个数组ra,sa,height。(后缀数组求出来的三个数组)
2、我们知道一个后缀能够贡献出n-sa[i]+1-height[i]个不相同的子串,而我们要包含字符x的话,我们就只要标记出某个位置后面第一个x的位置就可以了(记为nc[i]),我们就能求到一个后缀能够贡献出n-1-max(sa[i]+height[i],nc[sa[i]])。
代码:
#include<cstdio>
#include<cstring>
#include<iostream>
#include<sstream>
#include<algorithm>
#include<vector>
#include<bitset>
#include<set>
#include<queue>
#include<stack>
#include<map>
#include<cstdlib>
#include<cmath>
#define LL long long
#define pb push_back
#define pa pair<int,int>
#define clr(a,b) memset(a,b,sizeof(a))
#define lson lr<<1,l,mid
#define rson lr<<1|1,mid+1,r
#define bug(x) printf("%d++++++++++++++++++++%d\n",x,x)
#define key_value ch[ch[root][1]][0]
#pragma comment(linker, "/STACK:102400000000,102400000000")
const LL MOD = 1000000007;
const int N = 1e5+15;
const int maxn = 1e6+15;
const int letter = 130;
const int INF = 1e9+7;
const double pi=acos(-1.0);
const double eps=1e-10;
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
char x,s[N];
int sa[N],ra[N],height[N],c[maxn],n,t2[N],t[N],k,nc[N];
void build_sa(int n,int m){
int *x=ra,*y=t2;
for(int i=0;i<m;i++)c[i]=0;
for(int i=0;i<n;i++) c[x[i]=s[i]]++;
for(int i=1;i<m;i++)c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
for(int k=1;k<=n;k<<=1){
int p=0;
for(int i=n-k;i<n;i++) y[p++]=i;
for(int i=0;i<n;i++) if(sa[i]>=k) y[p++]=sa[i]-k;
for(int i=0;i<m;i++) c[i]=0;
for(int i=0;i<n;i++) c[x[y[i]]]++;
for(int i=1;i<m;i++) c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1,x[sa[0]]=0;
for(int i=1;i<n;i++){
x[sa[i]]=y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++;
}
if(p>=n) break;
m=p;
}
}
void build_height(int n){
int k=0;
for(int i=0;i<n;i++) ra[sa[i]]=i;
for(int i=0;i<n;i++){
if(k)k--;
int j=sa[ra[i]-1];
while(s[i+k]==s[j+k])k++;
height[ra[i]]=k;
}
}
int main(){
int T,cas=0;
scanf("%d",&T);
while(T--){
getchar();
clr(s,0);
scanf("%c%s",&x,s);
n=strlen(s);
s[n++]=0;
build_sa(n,256);
build_height(n);
for(int i=0;i<n;i++)nc[i]=n-1;
for(int i=0;i<n;i++){
if(s[i]==x) {
nc[i]=i;
for(int j=i-1;j>=0&&s[j]!=x;j--) nc[j]=i;
}
}
LL sum=0;
for(int i=1;i<n;i++){
sum+=1ll*(n-1)-1ll*max(sa[i]+height[i],nc[sa[i]]);
}
printf("Case #%d: %I64d\n",++cas,sum);
}
return 0;
}