一道字符串好题。网上大多解法是后缀数组+栈扫描,大致思想是按K分组以后统计s,t中不同公共字串的个数。但这种方法我实在是理解不能......
其实,用后缀自动机来考虑这个题可能会好一点。首先,我们给S串建立后缀自动机,然后让T串在S串上跑一下匹配,这是后缀自动机的经典操作。这样我们可以求得T的每一位向前最大能匹配多少位,用这个数目来统计答案。但是有一个问题,就是假设T的某一位u匹配了x长度,那么有可能S中有多个位置对应和T的u匹配x长度,于是我们首先就要知道每个点对应多少重复的字串。这恰恰应用了后缀自动机第二个经典操作。
好,我们来理一下思路:首先,给S建自动机以后,我们用排序+递推的方法求出每个点包含的后缀个数假设为sum;然后用T在S上跑匹配,对于T的每一位,我们匹配到了S自动机上的u节点,如果这个节点的len大于k,我们就不断地回溯它的父亲节点,直到最后一个len>=k的节点v,我们要统计u->v这条路径上对应的所有公共子串去更新答案。其实,这个过程就是我们先找到了一个匹配的最大后缀,然后不断地从前面缩短它,然后把每段子串都统计一下出现次数就行了。这个回溯的过程可以记忆化搜索一下。
1 #include<iostream> 2 #include<cstdio> 3 #include<algorithm> 4 #include<cmath> 5 #include<cstring> 6 #define maxn 100020 7 using namespace std; 8 typedef long long LL; 9 struct node 10 { 11 node *ch[53],*pre; 12 LL f,len,v; 13 void clear() 14 { 15 memset(ch,0,sizeof(ch)); 16 pre=0,len=0; 17 } 18 }sam[maxn*2],*rot,*now,*que[maxn*2]; 19 int wv[maxn]; 20 char s[maxn],t[maxn]; 21 int m,num; 22 23 void insert(int w) 24 { 25 node *p=now,*np=&sam[++num]; 26 np->len=p->len+1; 27 np->v=1; np->f=-1; 28 while (p&&p->ch[w]==0) p->ch[w]=np,p=p->pre; 29 if (!p) np->pre=rot; 30 else 31 { 32 node *q=p->ch[w]; 33 if (q->len==p->len+1) np->pre=q; 34 else 35 { 36 node *nq=&sam[++num]; 37 *nq=*q; 38 nq->len=p->len+1; 39 nq->v=0; nq->f=-1; 40 np->pre=q->pre=nq; 41 while (p&&p->ch[w]==q) p->ch[w]=nq,p=p->pre; 42 } 43 } 44 now=np; 45 } 46 47 LL back(node *p) 48 { 49 if (p->f!=-1) return p->f; 50 if (p->len<m) return p->f=0; 51 if (p->pre->len<m) return p->f=(p->v)*(p->len-(LL)m+1); 52 return p->f=back(p->pre)+(p->v)*(p->len-p->pre->len); 53 } 54 55 int main() 56 { 57 //freopen("com.in","r",stdin); 58 scanf("%d",&m); 59 while (m) 60 { 61 LL ans=0; 62 for (int i=0;i<=num;i++) sam[i].clear(); 63 rot=now=&sam[num=0]; 64 scanf("%s",s); 65 scanf("%s",t); 66 //对s建后缀自动机 67 int ls=strlen(s); 68 for (int i=0;i<ls;i++) 69 { 70 int w=(s[i]<='z'&&s[i]>='a')?(s[i]-'a'):(26+s[i]-'A'); 71 insert(w); 72 } 73 //求出s每个节点对应的后缀串的数目 74 for (int i=0;i<=num;i++) que[i]=0; 75 for (int i=0;i<=ls;i++) wv[i]=0; 76 for (int i=1;i<=num;i++) wv[sam[i].len]++; 77 for (int i=1;i<=ls;i++) wv[i]+=wv[i-1]; 78 for (int i=num;i;i--) que[wv[sam[i].len]--]=&sam[i]; 79 for (int i=num;i;i--) que[i]->pre->v+=que[i]->v; 80 //t串在s串上跑一边,求出跑到的节点 81 int lt=strlen(t); 82 LL tmp=0; 83 node *p=rot; 84 for (int i=0;i<lt;i++) 85 { 86 int w=(t[i]<='z'&&t[i]>='a')?(t[i]-'a'):(26+t[i]-'A'); 87 if (p->ch[w]) tmp++,p=p->ch[w]; 88 else 89 { 90 while (p&&p->ch[w]==0) p=p->pre; 91 if (p) tmp=p->len+1,p=p->ch[w]; 92 else tmp=0,p=rot; 93 } 94 //累加这个节点回溯到长度小于等于k的点的统计答案 95 if (tmp<m) continue ; 96 if (p->pre->len<m) ans=ans+(tmp-(LL)m+1)*p->v; 97 else ans+=(tmp-p->pre->len)*p->v+back(p->pre); 98 } 99 printf("%I64d\n",ans); 100 scanf("%d",&m); 101 } 102 return 0; 103 }
由于一个小错误,导致RE了半天......