题意
给定n个字符串,m次询问,每次给定两个字符串s1,s2,问有多少字符串既有s1前缀又有s2后缀
分析
给字符串按照字典序排序,建立字典树。翻转后再排序建立另一个字典树。那么每次询问就是分别走两个字典树的子树的交集。
排序后,前/后缀相同的字符串的编号一定是连续的一段,所以我们可以把问题转化为二维数点,每次给定一个矩形,问框住的点的个数。
这样复杂度就是一只log的,足以通过
其实可以排序后建立可持久化Trie,然后像在主席树上找第k大那样操作就能做到线性!(略)
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=5e5+5;
int n,m;
struct ss
{
string s;
int id;
}s1[maxn];
bool cmp(ss x,ss y)
{
return x.s<y.s;
}
int get(char cc)
{
if(cc=='A') return 1;
if(cc=='C') return 2;
if(cc=='U') return 3;
return 0;
}
struct Trie
{
int tr[maxn*10][5],tot;
int l[maxn*10],r[maxn*10];
void insert(string s,int id)
{
int now=1,len=s.size();
for(int i=0;i<len;i++)
{
int dig=get(s[i]);
if(!tr[now][dig]) tr[now][dig]=++tot,l[tot]=id;
now=tr[now][dig];
r[now]=id;
}
}
int query(string s)
{
int len=s.size();
int now=1;
for(int i=0;i<len;i++)
{
int dig=get(s[i]);
now=tr[now][dig];
if(!now) return 0;
}
return now;
}
}A,B;
int cnt;
struct point
{
int x,y,t,id;
}e[maxn*10];
bool cmpp(point a,point b)
{
if(a.x==b.x)
{
if(a.y==b.y) return a.t<b.t;
return a.y<b.y;
}
return a.x<b.x;
}
int flag[maxn];
ll ans[maxn],c[maxn*10];
int lowbit(int x)
{
return x&(-x);
}
void add(int x)
{
while(x<=cnt)
{
c[x]++;
x+=lowbit(x);
}
}
ll query(int x)
{
ll res=0;
while(x)
{
res+=c[x];
x-=lowbit(x);
}
return res;
}
char aa[maxn*10];
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
scanf("%d%d",&n,&m);
A.tot=B.tot=1;
for(int i=1;i<=n;i++) scanf("%s",aa),s1[i].s=aa;
sort(s1+1,s1+n+1,cmp);
for(int i=1;i<=n;i++) s1[i].id=i,A.insert(s1[i].s,s1[i].id);
for(int i=1;i<=n;i++) reverse(s1[i].s.begin(),s1[i].s.end());
sort(s1+1,s1+n+1,cmp);
for(int i=1;i<=n;i++) B.insert(s1[i].s,i);
for(int i=1;i<=n;i++)
e[++cnt]=(point){i,s1[i].id,0,i};
string a;
for(int i=1;i<=m;i++)
{
scanf("%s",aa); a=aa;
int t1=A.query(a);
scanf("%s",aa); a=aa;
reverse(a.begin(),a.end());
int t2=B.query(a);
if(!t1 || !t2)
{
flag[i]=1;
continue;
}
int l1=A.l[t1],r1=A.r[t1];
int l2=B.l[t2],r2=B.r[t2];
e[++cnt]=(point){l2-1,l1-1,1,i};
e[++cnt]=(point){r2,r1,1,i+m};
e[++cnt]=(point){l2-1,r1,1,i+m*2};
e[++cnt]=(point){r2,l1-1,1,i+m*3};
}
sort(e+1,e+cnt+1,cmpp);
for(int i=1;i<=cnt;i++)
{
if(!e[i].t) add(e[i].y);
else ans[e[i].id]=query(e[i].y);
}
for(int i=1;i<=m;i++)
{
if(flag[i]) printf("0\n");
else printf("%lld\n",ans[i]+ans[i+m]-ans[i+m*2]-ans[i+m*3]);
}
return 0;
}