原题:后缀数组求不同子串个数 。比赛时看着bin神博客http://www.cnblogs.com/kuangbin/archive/2013/04/24/3039634.html YY了两个小时,莫名A掉。
我的做法就是算出包含X字符的所有子串数,再减去其中重复的就是答案。
#include <iostream>
#include <string.h>
#include <algorithm>
#include <stdio.h>
typedef long long ll;
using namespace std;
const int MAXN=200000+6;
char X;
int sum[MAXN];
bool vis[MAXN];
int res[MAXN];
int siz,pre;
/*
*suffix array
*倍增算法 O(n*logn)
*待排序数组长度为n,放在0~n-1中,在最后面补一个0
*build_sa( ,n+1, );//注意是n+1;
*getHeight(,n);
*例如:
*n = 8;
*num[] = { 1, 1, 2, 1, 1, 1, 1, 2, $ };注意num最后一位为0,其他大于0
*rank[] = { 4, 6, 8, 1, 2, 3, 5, 7, 0 };rank[0~n-1]为有效值,rank[n]必定为0无效值
*sa[] = { 8, 3, 4, 5, 0, 6, 1, 7, 2 };sa[1~n]为有效值,sa[0]必定为n是无效值
*height[]= { 0, 0, 3, 2, 3, 1, 2, 0, 1 };height[2~n]为有效值
*
*/
int sa[MAXN];//SA数组,表示将S的n个后缀从小到大排序后把排好序的
//的后缀的开头位置顺次放入SA中
int t1[MAXN],t2[MAXN],c[MAXN];//求SA数组需要的中间变量,不需要赋值
int rank2[MAXN],height[MAXN];
//待排序的字符串放在s数组中,从s[0]到s[n-1],长度为n,且最大值小于m,
//除s[n-1]外的所有s[i]都大于0,r[n-1]=0
//函数结束以后结果放在sa数组中
void build_sa(int s[],int n,int m)
{
int i,j,p,*x=t1,*y=t2;
//第一轮基数排序,如果s的最大值很大,可改为快速排序
for(i=0;i<m;i++)c[i]=0;
for(i=0;i<n;i++)c[x[i]=s[i]]++;
for(i=1;i<m;i++)c[i]+=c[i-1];
for(i=n-1;i>=0;i--)sa[--c[x[i]]]=i;
for(j=1;j<=n;j<<=1)
{
p=0;
//直接利用sa数组排序第二关键字
for(i=n-j;i<n;i++)y[p++]=i;//后面的j个数第二关键字为空的最小
for(i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j;
//这样数组y保存的就是按照第二关键字排序的结果
//基数排序第一关键字
for(i=0;i<m;i++)c[i]=0;
for(i=0;i<n;i++)c[x[y[i]]]++;
for(i=1;i<m;i++)c[i]+=c[i-1];
for(i=n-1;i>=0;i--)sa[--c[x[y[i]]]]=y[i];
//根据sa和x数组计算新的x数组
swap(x,y);
p=1;x[sa[0]]=0;
for(i=1;i<n;i++)
x[sa[i]]=y[sa[i-1]]==y[sa[i]] && y[sa[i-1]+j]==y[sa[i]+j]?p-1:p++;
if(p>=n)break;
m=p;//下次基数排序的最大值
}
}
void getHeight(int s[],int n)
{
int i,j,k=0;
for(i=0;i<=n;i++)rank2[sa[i]]=i;
for(i=0;i<n;i++)
{
if(k)k--;
j=sa[rank2[i]-1];
while(s[i+k]==s[j+k])k++;
height[rank2[i]]=k;
int tmp=0;
if(!vis[i]){
tmp=upper_bound(sum+i,sum+i+height[rank2[i]],sum[i])-(sum+i);
}
height[rank2[i]]-=tmp;
}
}
char str[MAXN];
int s[MAXN];
int main()
{
int T;
scanf("%d",&T);
int cas=1;
while(T--)
{
char cc[10];
scanf("%s",cc);
X=cc[0];
scanf("%s",str);
int n=strlen(str);
for(int i=0;i<=n;i++)s[i]=str[i];
build_sa(s,n+1,128);
memset(vis,0,sizeof(vis));
sum[0]=0;pre=-1;
if(str[0]==X){pre=0;sum[0]=1;vis[0]=true;}
for(int i=1;i<=n;i++){
if(str[i]==X){sum[i]=sum[i-1]+1;vis[i]=true;}
else sum[i]=sum[i-1];
}
siz=0;
for(int i=1;i<n;i++){
if(str[i]==X){
res[siz++]=i-pre-1;
pre=i;
}
}
res[siz++]=n-pre-1;
getHeight(s,n);
ll haha=n;
ll ans=haha*(haha+1)/2;
for(int i=0;i<siz;i++){
haha=res[i];
ans=ans-haha*(haha+1)/2;
}
for(int i=2;i<=n;i++)ans-=height[i];
printf("Case #%d: %I64d\n",cas++,ans);
}
return 0;
}