题目链接http://acm.hdu.edu.cn/showproblem.php?pid=576
题目意思也就是给你一个字符x,问字符串str中有多少个包含x的子串(x在子串中至少出现一次)比赛是第一思路就是后缀数组,但无奈是今天早上才开始学习后缀数组,下午比赛的时候又是一脸懵逼,比赛后发现挺简单,最近几天也系统的学习一下后缀数组和字符串问题
首先字符串str中子串的总个数为sum(len-(sa[i]+height[i]))然后就是求解包含字符x的子串个数,只需每个离sa[i]最近位置的x的位置即可,答案即为sum(len-max(next[sa[i]],sa[i]+length[i]) 赛后补题套的是kuangbin的模板,莫名的wa,换了模板之后果断AC(同时注意答案为long long),代码奉上
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
using namespace std;
const int maxn=100000+5;
int len,n;
int nxt[maxn];
int num[maxn];
int sa[maxn], Rank[maxn], height[maxn];
int wa[maxn], wb[maxn], wv[maxn], wd[maxn];
int cmp(int *r, int a, int b, int l){
return r[a] == r[b] && r[a+l] == r[b+l];
}
void da(int *r, int n, int m){
int i, j, p, *x = wa, *y = wb, *t;
for(i = 0; i < m; i ++) wd[i] = 0;
for(i = 0; i < n; i ++) wd[x[i]=r[i]] ++;
for(i = 1; i < m; i ++) wd[i] += wd[i-1];
for(i = n-1; i >= 0; i --) sa[-- wd[x[i]]] = i;
for(j = 1, p = 1; p < n; j *= 2, m = p){
for(p = 0, i = n-j; i < n; i ++) y[p ++] = i;
for(i = 0; i < n; i ++) if(sa[i] >= j) y[p ++] = sa[i] - j;
for(i = 0; i < n; i ++) wv[i] = x[y[i]];
for(i = 0; i < m; i ++) wd[i] = 0;
for(i = 0; i < n; i ++) wd[wv[i]] ++;
for(i = 1; i < m; i ++) wd[i] += wd[i-1];
for(i = n-1; i >= 0; i --) sa[-- wd[wv[i]]] = y[i];
for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i ++){
x[sa[i]] = cmp(y, sa[i-1], sa[i], j) ? p - 1: p ++;
}
}
}
void calHeight(int *r, int n){ // 求height数组。
int i, j, k = 0;
for(i = 1; i <= n; i ++) Rank[sa[i]] = i;
for(i = 0; i < n; height[Rank[i ++]] = k){
for(k ? k -- : 0, j = sa[Rank[i]-1]; r[i+k] == r[j+k]; k ++);
}
}
char aim,str[maxn];
int pos;
int main()
{
int l=1;
int T;
scanf("%d",&T);
while(T--)
{
memset(nxt,0,sizeof(nxt));
scanf(" %c",&aim);
scanf(" %s",str);
len=strlen(str);
for(int i=0;i<len;i++)num[i]=str[i]-'a'+1;
num[len]=0;
da(num,len+1,128);
calHeight(num,len);
nxt[len]=len;
for(int i=len-1;i>=0;i--)
{
nxt[i]=(str[i]==aim)?i:nxt[i+1];
}
// for(int i=0;i<=len;i++)
// {
// printf("%d ",sa[i]);
// }
// cout <<endl;
// for(int i=0;i<=len;i++)
// {
// printf("%d ",height[i]);
// }
// cout << endl;
long long ans=0;
for(int i=1;i<len+1;i++)
{
//cout << max(nxt[sa[i]],sa[i]+height[i]) << endl;
ans+=(len-max(nxt[sa[i]],sa[i]+height[i]));
}
printf("Case #%d: %I64d\n",l++,ans);
}
}