题目要求:求s1,s2的最大子串
思路:将s1,s2合并为一个字符串s, 也就是求s的max(lcp[i][j]), 唯一i,j分别位于s1和s2, 利用后缀数组计算s的h[]或者height[]数组,那么答案即为h中的最大值。
证明如下:s1,s2一定存在最大子串t,t为s1的子串t1和s2的子串t2的lcp[t1][t2], 假设t1,t2在后缀数组中不相邻,则任意取后缀数组中位于t1,t2之间的串tt, 则lcp[tt][t1]与lcp[tt][t2]中至少有一个同时满足
(1) >=lcp[t1][t2] :这是由后缀数组的性质所决定的
(2) 两串分别位于s1,s2:这是因为t1,t2分别属于s1,s2
从而得到更优的解,所以答案一定为h中的最大值
- #include <iostream>
- #include <string>
- #include <algorithm>
- using namespace std;
- #define Min(a,b) (a)<(b)?(a):(b)
- const int N = 201000;
- int n, m;
- char s[N], s2[N];
- int cnt[N], mem[4][N], *rank, *nrank, *sa, *nsa, h[N];
- // lcp[i][j]: longest commen prefix ( suffix(sa[k+1]), suffix(sa[k]) ) j <= k < j+2^i
- void radix_sort()
- {
- int i, j, k;
- rank = mem[0];
- nrank = mem[1];
- sa = mem[2];
- nsa = mem[3];
- for(i = 0; i < n; i++) cnt[s[i]]++;
- for(i = 1; i < 256; i++) cnt[i] += cnt[i-1];
- for(i = n-1; i >= 0; i--) sa[--cnt[s[i]]] = i;
- for(rank[0]=0, i=1; i < n; i++)
- {
- rank[sa[i]] = rank[sa[i-1]];
- if(s[sa[i]]!=s[sa[i-1]]) rank[sa[i]]++;
- }
- for(k = 1; k<n && rank[sa[n-1]] < n-1; k*=2)
- {
- for(i = 0; i < n; i++) cnt[rank[sa[i]]] = i+1;
- for(i = n-1; i >= 0; i--) if(sa[i]-k>=0)
- nsa[--cnt[rank[sa[i]-k]]] = sa[i]-k;
- // max(sa[i]-k)=n-k-1 , therefore i = n-k;
- for(i = n-k; i < n; i++)
- nsa[--cnt[rank[i]]] = i;
- for(nrank[nsa[0]], i=1; i < n; i++)
- {
- nrank[nsa[i]] = nrank[nsa[i-1]];
- if(rank[nsa[i]] != rank[nsa[i-1]]
- || rank[nsa[i]+k] != rank[nsa[i-1]+k])
- nrank[nsa[i]]++;
- }
- swap(rank, nrank);
- swap(sa, nsa);
- }
- }
- void get_lcp_rmq()
- {
- int i, j, k;
- for(i=0,k=0; i<n; i++)
- {
- if(rank[i]==n-1) h[rank[i]]=k=0;
- else
- {
- if(k>0)k--;
- j = sa[rank[i]+1];
- for(;s[i+k]==s[j+k];k++) ;
- h[rank[i]]=k;
- }
- }
- }
- int main()
- {
- int i, j, k;
- int p1, p2, n1;
- gets(s);
- n1 = strlen(s);
- s[n1++]='#';
- gets(s2);
- strcat(s,s2);
- n = strlen(s);
- s[n++]=0;
- radix_sort();
- get_lcp_rmq();
- int ans = 0;
- for(i = 0; i < n-1; i++)
- {
- j = sa[i];
- if(j < n1)p1 = 1;
- else p1 = -1;
- k = sa[i+1];
- if(k < n1)p2 = 1;
- else p2 = -1;
- if(p1*p2<1 && h[i]>ans)
- ans = h[i];
- }
- printf("%d/n", ans);
- return 0;
- }