测试地址:动态逆序对
做法:本人这几天学习了CDQ分治思想,感觉还是比较难懂,于是找到了比较好理解的经典应用——三维偏序问题来加深理解。
这题首先需要把问题转化为三维偏序问题,然后再使用CDQ分治解决。
首先这个题目是将元素一个一个删除,在每次删除之前询问逆序对数,从这个方面来看好像无法下手,那么我们不如反过来,看成是将元素一个一个插入,在每次插入之后询问逆序对数。那么每个元素我们就可以使用一个三维坐标
(xi,yi,zi)
来表示,其中
xi
指元素的插入时间(以插入先后顺序标号为1~M,一开始就在的标号为0),
yi
指元素在排列中的位置,
zi
指元素的值。那么对于一个点
(xi,yi,zi)
,如果存在
newi
个点
(xj,yj,zj)
使得
xi≤xj
,
yi≤yj
且
zi≥zj
或
xi≤xj
,
yi≥yj
且
zi≤zj
,那么在第
xi
次插入之后逆序对数就会增加
newi
个(想一想,为什么?)。于是我们就得到了一个变形的三维偏序问题,我们需要想办法求出所有的
newi
。
由于N达到100000,所以
O(N2)
的暴力是绝对炸的。网上有人讲解三维偏序问题时说了一句精辟的话:一维排序,二维分治,三维数据结构。按照这个思路,我们首先把所有点按
x
从小到大排序,重新标号为1~N,然后分治。这里使用的分治方法是CDQ分治,CDQ分治是一种思想,包含递归处理左半、处理左半对右半的影响、递归处理右半三个步骤。以下只考虑怎么处理左半对右半的影响。
假设我们在处理一个区间
经过证明,以上方法的时间复杂度为
O(Nlog2N)
,可以通过全部数据。注意每次处理完后清空树状数组时不要鲁莽地使用memset,会TLE,应该按照原来的顺序再把加上的东西都给减掉。除此之外,要注意排序和处理的顺序,因为有时排序会破坏掉原来的顺序。
以下是本人代码:
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
using namespace std;
int n,m,pos[100010]={0};
ll ans[50010]={0},bit[100010]={0};
struct point3D
{
int x,y,z,id;
}p[100010];
bool cmpx(point3D a,point3D b) {return a.x<b.x;}
bool cmpy1(point3D a,point3D b) {return a.y<b.y;}
bool cmpy2(point3D a,point3D b) {return a.y>b.y;}
bool cmpid(point3D a,point3D b) {return a.id<b.id;}
int lowbit(int x)
{
return x&(-x);
}
void add(int x,ll d)
{
for(int i=x;i<=n;i+=lowbit(i))
bit[i]+=d;
}
ll query(int x)
{
ll s=0;
while(x)
{
s+=bit[x];
x-=lowbit(x);
}
return s;
}
ll sum(int l,int r)
{
return query(r)-query(l-1);
}
void solve(int l,int r)
{
int mid=(l+r)>>1;
if (l==r) return;
solve(l,mid);
int h;
sort(p+l,p+mid+1,cmpy1);
sort(p+mid+1,p+r+1,cmpy1);
h=l;
for(int i=mid+1;i<=r;i++)
{
while(h<=mid&&p[h].y<=p[i].y) add(p[h].z,1),h++;
ans[m-p[i].x+1]+=sum(p[i].z,n);
}
for(int i=l;i<h;i++) add(p[i].z,-1);
sort(p+l,p+mid+1,cmpy2);
sort(p+mid+1,p+r+1,cmpy2);
h=l;
for(int i=mid+1;i<=r;i++)
{
while(h<=mid&&p[h].y>=p[i].y) add(p[h].z,1),h++;
ans[m-p[i].x+1]+=sum(1,p[i].z);
}
for(int i=l;i<h;i++) add(p[i].z,-1);
sort(p+l+1,p+r+1,cmpid);
solve(mid+1,r);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
p[i].y=i;
scanf("%d",&p[i].z);
}
for(int i=1;i<=m;i++)
{
int a;
scanf("%d",&a);
pos[a]=m-i+1;
}
for(int i=1;i<=n;i++)
p[i].x=pos[p[i].z];
sort(p+1,p+n+1,cmpx);
for(int i=1;i<=n;i++) p[i].id=i;
solve(1,n);
for(int i=m;i>=1;i--)
ans[i]+=ans[i+1];
for(int i=1;i<=m;i++)
printf("%lld\n",ans[i]);
return 0;
}