题目概述
给出n个串和T个询问,每个询问求x串y串的最长公共子串。
解题报告
两个串的最长公共子串(下文简称LCS)可以用SAM方便的求出。那么这道题可以先对所有串预处理出SAM,然后就可以快速求出两个串的LCS了。
但这样还是不够快,我们可以采取这样的策略:由于一个长度为len的串与任意一个SAM匹配的时间都是
O(len)
,所以对于两个串A,B,让长度小的串和长度大的串的SAM匹配较好。
再采取进一步优化:如果
n2
<=te,预处理出所有LCS[i][j]比每组询问都处理效果来的好,反之如果
n2
>te,那么莫不如每组询问都处理(再加上记忆化来保证效率不降低,我比较懒就直接用map了:P)。
经过这两个优化之后,效率如何估计呢?由于我们都是让小的去匹配大的,所以当所有串长度相同时,效率最低,设长度为w,所有字符串总长度为sum。那么不算上map的话(map可以换成hashmap使记忆化的效率接近
O(1)
),效率应该是
O(w∗min((sumw)2,Q))
。分类讨论一下就可以得到复杂度上限:
O(Q∗sum−−−−√)
。
标题
#include<cstdio>
#include<cstring>
#include<map>
using namespace std;
typedef long long LL;
const int maxn=50000,maxm=500,maxl=100000;
int n,te,st[maxn+5],len[maxn+5],ans[maxm+5][maxm+5];
char s[maxl+maxn+5];
struct SAM
{
struct node
{
node *son[26],*fa;int MAX;
node(int M) {MAX=M;fa=0;memset(son,0,sizeof(son));}
};
typedef node* P_node;
P_node ro,lst;
SAM() {ro=new node(0);lst=ro;}
void Insert(char ch)
{
int ID=ch-'a';P_node p=lst,np=new node(p->MAX+1);
while (p&&!p->son[ID]) p->son[ID]=np,p=p->fa;
if (!p) np->fa=ro; else
{
P_node q=p->son[ID];
if (p->MAX+1==q->MAX) np->fa=q; else
{
P_node nq=new node(p->MAX+1);
memcpy(nq->son,q->son,sizeof(q->son));
nq->fa=q->fa;q->fa=np->fa=nq;
while (p&&p->son[ID]==q) p->son[ID]=nq,p=p->fa;
}
}
lst=np;
}
void make_SAM(char *s) {for (int i=1;s[i];i++) Insert(s[i]);}
int LCS(char *s)
{
P_node p=ro;int len=0,ans=0;
for (int i=1;s[i];i++)
{
int ID=s[i]-'a';
if (p->son[ID]) len++,p=p->son[ID]; else
{
while (p&&!p->son[ID]) p=p->fa;
if (p) len=p->MAX+1,p=p->son[ID]; else
len=0,p=ro;
}
if (len>ans) ans=len;
}
return ans;
}
};
SAM sam[maxn+5];
struct Pair
{
int a,b;Pair(int A=0,int B=0) {a=A;b=B;}
bool operator < (const Pair &c) const {return a<c.a||a==c.a&&b<c.b;}
};
map<Pair,int> f;
bool Eoln(char ch) {return ch==10||ch==13||ch==EOF;}
char readc()
{
static char buf[100000],*l=buf,*r=buf;
if (l==r) r=(l=buf)+fread(buf,1,100000,stdin);
if (l==r) return EOF; else return *l++;
}
int readi(int &x)
{
int tot=0,f=1;char ch=readc(),lst='+';
while ('9'<ch||ch<'0') {if (ch==EOF) return EOF;lst=ch;ch=readc();}
if (lst=='-') f=-f;
while ('0'<=ch&&ch<='9') tot=tot*10+ch-48,ch=readc();
x=tot*f;
return Eoln(ch);
}
int reads(char *s)
{
int len=0;char ch=readc();if (ch==EOF) return EOF;
s[++len]=ch;while (!Eoln(s[len])) s[++len]=readc();s[len--]=0;
return len;
}
void writei(int x)
{
static char buf[10];int len=0;
if (x<0) putchar('-'),x=-x;
do {buf[len++]=x%10+48;x/=10;} while (x);
while (len--) putchar(buf[len]);
}
int main()
{
freopen("program.in","r",stdin);
freopen("program.out","w",stdout);
readi(n);readi(te);
for (int i=1,lst=0;i<=n;i++)
{
len[i]=reads(s+(st[i]=lst));lst+=len[i]+1;
sam[i].make_SAM(s+st[i]);
}
if ((LL)n*n<te)
{
for (int i=1;i<=n;i++)
for (int j=i;j<=n;j++)
if (len[i]<len[j]) ans[i][j]=ans[j][i]=sam[j].LCS(s+st[i]); else
ans[i][j]=ans[j][i]=sam[i].LCS(s+st[j]);
while (te--)
{
int x,y;readi(x);readi(y);x++;y++;
writei(ans[x][y]);putchar('\n');
}
} else
{
f.clear();
while (te--)
{
int x,y;readi(x);readi(y);x++;y++;
if (len[x]>len[y]) swap(x,y);
if (!f.count(Pair(x,y))) f[Pair(x,y)]=sam[y].LCS(s+st[x]);
writei(f[Pair(x,y)]);putchar('\n');
}
}
return 0;
}