给出字符串A和B。求A的不包含B的不同字串的数量。
解法。。字符串匹配+后缀数组。
任何字串都会是某个后缀的前缀
对字串A建立后缀数组
对于排名第k的后缀 SA[k] 将会提供L-SA[k]+1个字串 其中有height[k]个前缀在之前计算过
所以每个后缀提供的新字串个数为L-SA[k]+1-height[k]
下面考虑不能包含B。对A,B做一次匹配。找出B在A中出现的位置。设为p1,p2,p3...pn;
对于排名第k后缀SA[k],其左边界为SA[k],右边界于与边界之间不能包含B。
即右边界为大于SA[k]的第一个pi+length(B)-2.
预处理出对于每一个左边界对应的最大合法长度MR[k]
则对于排名第k的后缀SA[k]对应的最大长度为 min(MR[SA[k]],L-SA[k]+1)
最小长度为 height[k]
二者之差大于0则累加答案。
由于使用了后缀数组,后缀数组本身具有匹配字符串功能。
将AB用‘$'拼接成一个新字符串,构造一次后缀数组。
就不用再写一个KMP函数。。通过两次构造后缀数组即可。
#include <iostream> #include <cstdio> #include <cstring> using namespace std; const int MAXN = 100010; char s[MAXN]; int NE[MAXN]; int HE[MAXN],TI[MAXN]; int X[MAXN],Y[MAXN],SA[MAXN]; int height[MAXN]; int MR[MAXN]; int *rank,*SV; int L,L1,L2,p,q; int NP; int getfir(int &k){ while (++k<=NP) if (HE[k]) return k; return 0; } bool eq(int i,int j,int l){ int p1,p2; p1 = SV[i]; p2 = SV[j]; if (p1!=p2) return false; p1 = i+l/2<=L? SV[i+l/2]:0; p2 = j+l/2<=L? SV[j+l/2]:0; return p1==p2; } int Jsort(int l){ memset(HE,0,sizeof(HE)); memset(TI,0,sizeof(TI)); for (int i=L-l/2+1;i<=L;i++){ if (HE[rank[i]]){ NE[TI[rank[i]]] = i; TI[rank[i]] = i; } else HE[rank[i]]=TI[rank[i]]=i; } for (int i=1;i<=L;i++){ if (SA[i]<l/2+1) continue; int now = SA[i]-l/2; if (HE[rank[now]]){ NE[TI[rank[now]]] = now; TI[rank[now]] = now; } else HE[rank[now]]=TI[rank[now]]=now; } int K=0,all=0; while (getfir(K)){ NE[TI[K]]=0; int kk = HE[K]; while (kk){ SA[++all]=kk; kk = NE[kk]; } } swap(rank,SV); for (int i=1;i<=L;i++){ if (eq(SA[i-1],SA[i],l)) rank[SA[i]]=rank[SA[i-1]]; else rank[SA[i]]=rank[SA[i-1]]+1; } NP = rank[SA[L]]; return NP; } void prework(){ memset(NE,0,sizeof(NE)); rank = X; SV = Y; NP =28; for (int i=1;i<=L;i++){ if (s[i]=='$') rank[i]=27; else if (s[i]=='#') rank[i]=28; else rank[i]=s[i]-'a'+1; } for (int i=1;i<=L;i++) SA[i]=i; } void calcheight(){ int K=0; for (int i=1;i<=L;i++){ while (s[i+K]==s[SA[rank[i]-1]+K]) K++; height[rank[i]]=K?K--:0; } } void input(){ scanf("%s",s+1); L1 = strlen(s+1); s[++L1]='$'; scanf("%s",s+L1+1); L2 = strlen(s+L1+1); L = L1+L2; } int main(){ int TT,cas=0; scanf("%d",&TT); while (TT--){ input(); prework(); for (int l=1;Jsort(l)<L;l<<=1); calcheight(); memset(MR,-1,sizeof(MR)); for (int i=rank[L1+1]+1;i<=L;i++) if (height[i]<L2) break; else MR[SA[i]]=L2-1; for (int i=L;i>0;i--) if (MR[i]<0) MR[i]=MR[i+1]+1; s[L1]=0; L = L1; prework(); for (int l=1;Jsort(l)<L;l<<=1); calcheight(); long long ans=0; for (int i=1;i<L;i++){ int r = min(MR[SA[i]],L1-SA[i]); int l = height[i]; if (r-l>0) ans+=(r-l); } printf("Case %d: %lld\n",++cas,ans); } }