莫队算法
有时候我们经常会碰到这样一类问题:给定n和n个数etc,然后给出m组区间询问[L,R],要求对所有询问区间给出答案。
然后发现这类题通常有一个很好的性质就是,如果你知道了[L,R]的答案,就可以O(1)或者O(lgn)(再大就有点玄了)的知道
[L+1,R],[L-1,R],[L,R+1],[L,R-1]的答案,也就是可以很快的拓展左右节点,那么于是发现,与其在线回答每一个问题,不如找到一种离线回答顺序,
使得根据第i个区间[Li,Ri]的答案拓展到第i+1个区间[L_(i+1),R_(i+1)]的答案的用时之和最少。
好吧,上面这句话有点难懂。我们再详细点说。
不妨假设你可以做到O(1)的扩展左右端点,即由[L,R]的答案可以O(1)的知道[L+1,R],[L-1,R],[L,R+1],[L,R-1]的答案。
那么,如果你知道了[Li,Ri]的答案,你就可以在O( | Li - L_(i+1) | + | Ri - R_(i+1) | )的复杂度上知道[L_(i+1),R_(i+1)]的答案。
如果把[L,R]看作平面上的点(L,R),那么从一个点(也就是答案)扩展到另一个点(下一个答案)就是两点的曼哈顿距离。
我们可以证明存在一种走法,走过所有点,且移动距离之和是O(n*sqrt(n))的。
但是这个大概常数巨大而且不好写,所以我们有一种替代品,也可以说是改良版,就是分块。
具体做法是这样的:
我们把所有询问都进来,然后把那个长为n的序列分块,分为sqrt(n)块。
按照左端点所在块的编号(记为pos( query[i].left ))升序排序,如果左端点在同一个块内,就按照右端点升序排序。
有读者可能会问,为什么不直接按照左端点升序-右端点升序排序呢?
事实上,读者可以构造一个例子,在最坏情况下,按照这种暴力排序的方法复杂度是O(n^2)级别的。
那么问题来了:为什么莫队算法的排序保证时间复杂度是O(n*sqrt(n))的呢?
简单证明如下:
当我们第i个询问转移的第i+1个询问时
1)如果第i个询问区间和第i+1个询问区间的左端点所在块的编号相同,那么左端点的移动不会超过sqrt(n)。
也就是说,左端点一直在块内移动的总复杂度为O(n*sqrt(n))(因为左端点最多转移n次,减去左端点跨越块的部分,不足n)
同时由于右端点升序,那么若s,s+1,,,t-1,t的询问区间左端点所在块的编号相等,那么右端点的移动不会超过n次。有一位有sqrt(n)个块,
所以这一部分的复杂度是O(n*sqrt(n))的。
2)考虑左端点跨越块的情况,每次跨越最大是O(2*sqrt(n))那么左端点跨越块的复杂度O(n*sqrt(n))的。
又在这个期间,每次左端点跨越的时候,右端点可能要移动O(n)次,一共左端点跨越sqrt(n)个块,所以右端点复杂度是O(n*sqrt(n))的。
综上莫队算法的排序保证时间复杂度是O(n*sqrt(n))的。
可以进一步的说,莫队算法的总复杂度是O(m*sqrt(n)*F(n)+m*G(n))的,m是询问数。其中F(n)是拓展一次端点的复杂度,通常为O(1)或者O(lgn)。G(n)是根据现有信息计算一个区间答案的复杂度,通常为O(1)或者O(lgn)。由此可以看出,当我们有多种可以维护莫队算法的数据结构的时候,要尽量使修改操作复杂度低。一个例子是BZOJ3809的一道题,如果用树状数组,修改和查询都是O(lgn)的,那么套一个莫队复杂度就是O(m*sqrt(n)*lgn)的。但是我们可以用分块代替,这样虽然询问是O(sqrt(n))的,但是修改却变成O(1)的了,总复杂度就是O(m*sqrt(n))。
莫队算法模板:
#include<cstdio>
#include<cmath>
#include<algorithm>
#define MAXN //最大序列长度
#define MAXM //最大操作次数
#define pos(i) (i/sz)
using namespace std;
int ans[MAXM],Ans,sz;
//以及其他辅助数组和数据结构
struct query{//离线询问
int l,r,id;
}q[MAXM];
bool cmp(const query &q1,const query &q2)//排序比较的函数
{
if(pos(q1.l)==pos(q2.l)) return q1.r<q2.r;
else return pos(q1.l)<pos(q2.l);
}
int add(int x)
{
//加入a[x]这个点对答案的影响
}
int del(int x)
{
//删除a[x]这个点对答案的影响
}
int getAns()
{
//通过现有信息计算答案
}
int main()
{
int n,m;scanf("%d%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
sz=sqrt(n+0.5);//避免精度误差
for(int i=1;i<=m;i++)
{
scanf("%d%d",&q[i].l,&q[i].r);
q[i].id=i;//记录编号,因为要按顺序输出
}
sort(q+1,q+m+1,cmp);//对询问离线排序
int L=1,R=0;//初始的区间为空,惯用写法
for(int i=1;i<=m;i++)
{
while(L<q[i].l) del(L++);//删除左端点
while(L>q[i].l) add(--L);//加入左端点
while(R<q[i].r) add(++R);//加入右端点
while(R>q[i].r) del(R--);//删除右端点
ans[q[i].id]=getAns();//此时L=q[i].l,R=q[i].r,用计算出的答案更新ans[q[i].id]。
}
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);//输出答案
return 0;
}
大概就是这样。