Description
给定n个字符串和q个询问
每次询问在这n个字符串中,有多少个字符串同时满足
1. 字符串a是它的前缀
2. 字符串b是它的后缀
100%数据满足n,q≤50000,字符串长度不超过100,任意两串最长公共前缀和最长公共后缀较短(interesting)
Analysis
我比赛的思路是将字符串按前缀字典序排序,然后二分出一段区间
那么我们需要统计的就是这段区间内的字符串有多少个b是它的后缀
考场想的哈希+vector上二分来算,GG了
其实可以也按后缀排序,也二分出一段区间
于是询问就是形如第一维坐标在[x1,y1]内,第二维坐标在[x2,y2]内的点数
这是经典套路题,可以将询问拆成4个,容斥搞搞
那么对于单个询问,可以按第一维为关键字排序,第二维维护一个树状数组
至于字符串的排序,因为题目特定条件,可以快拍,当然桶排也行
Code
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,b,a) for(int i=b;i>=a;i--)
using namespace std;
const int N=50002,L=102;
int n,m,num,l1,l2,ra[N],rb[N],d[N],c[N],BIT[N],ans[N*4];
char s1[L],s2[L];
struct lyd
{
int l,id;
char s[L];
}a[N],b[N];
struct node
{
int x,y,id;
}f[N*4];
bool cmppx(lyd a,lyd b)
{
fo(i,1,min(a.l,b.l))
if(a.s[i]!=b.s[i]) return a.s[i]<b.s[i];
return a.l<b.l;
}
bool cmpc(node a,node b)
{
return a.x<b.x || a.x==b.x && a.y<b.y;
}
int cmp(char *s,lyd a)
{
int l=strlen(s+1);
fo(i,1,min(l,a.l))
if(s[i]<a.s[i]) return 1;
else
if(s[i]>a.s[i]) return 0;
if(l<a.l) return 1;
if(l>a.l) return 0;
return 2;
}
int findl(int l,int r,int bz)
{
int mid;
while(l<r)
{
mid=(l+r)>>1;
int k;
if(!bz)
{
int t=a[mid].l;
a[mid].l=l1;
k=cmp(s1,a[mid]);
a[mid].l=t;
}
else
{
int t=b[mid].l;
b[mid].l=l2;
k=cmp(s2,b[mid]);
b[mid].l=t;
}
if(k>0) r=mid;
else l=mid+1;
}
return l;
}
int findr(int l,int r,int bz)
{
int mid;
while(l<r)
{
mid=(l+r)>>1;
int k;
if(!bz)
{
int t=a[mid].l;
a[mid].l=l1;
k=cmp(s1,a[mid]);
a[mid].l=t;
}
else
{
int t=b[mid].l;
b[mid].l=l2;
k=cmp(s2,b[mid]);
b[mid].l=t;
}
if(k==0 || k==2) l=mid+1;
else r=mid;
}
return l;
}
int lowbit(int x)
{
return x&-x;
}
void add(int x)
{
for(int i=x;i<=n;i+=lowbit(i)) BIT[i]++;
}
int get(int x)
{
int t=0;
for(int i=x;i;i-=lowbit(i)) t+=BIT[i];
return t;
}
int main()
{
scanf("%d %d\n",&n,&m);
fo(i,1,n)
{
scanf("%s\n",a[i].s+1),a[i].l=strlen(a[i].s+1);
b[i].l=a[i].l;
fo(j,1,b[i].l) b[i].s[j]=a[i].s[b[i].l-j+1];
a[i].id=b[i].id=i;
}
sort(a+1,a+n+1,cmppx);
fo(i,1,n) ra[a[i].id]=i;
sort(b+1,b+n+1,cmppx);
fo(i,1,n) rb[b[i].id]=i;
fo(i,1,n) d[ra[i]]=rb[i];
a[n+1].s[a[n+1].l=1]='z',b[n+1].s[a[n+1].l=1]='z';
fo(k,1,m)
{
scanf("%s\n%s\n",s1+1,s2+1);
l1=strlen(s1+1),l2=strlen(s2+1);
fo(i,1,l2/2) swap(s2[i],s2[l2-i+1]);
int x1=findl(0,n+1,0),y1=findr(0,n+1,0);y1--;
int x2=findl(0,n+1,1),y2=findr(0,n+1,1);y2--;
c[k]=++num;
f[num].x=y1,f[num].y=y2,f[num].id=num;
f[++num].x=x1-1,f[num].y=y2,f[num].id=num;
f[++num].x=y1,f[num].y=x2-1,f[num].id=num;
f[++num].x=x1-1,f[num].y=x2-1,f[num].id=num;
}
sort(f+1,f+num+1,cmpc);
int j=1;
fo(i,1,num)
{
int x=f[i].x,y=f[i].y;
while(j<=x) add(d[j]),j++;
ans[f[i].id]=get(y);
}
fo(i,1,m)
{
int p=c[i];
printf("%d\n",ans[p]-ans[p+1]-ans[p+2]+ans[p+3]);
}
return 0;
}