题目
思路: 每个删除的元素有一个删除时间t,位置pos,值val。
对于每个元素删除,总序列减少的逆序对的个数为:
- t1>t0,pos1<pos0,val1>val0;
- t1>to,pos1>pos0,val1<val0;
就可以转化为三维偏序问题。先对时间从大到小排序。
cdq分治时对第二维pos排序,然后统计左边对右边的影响,左边无论删不删出的元素都要add(q[i].val)到树状数组中去,因为删除时间小于右边的啊,会对右边有影响的啊。右边统计答案的时候直接这个元素会被删除再统计就好了,统计也没必要啊嘻嘻嘻。
还有最后最后输出答案的时候。注意我们这里cdq是统计删除每个元素减少的逆序对的个数,题目问的是删除某个元素前逆序对的个数,那就是删除前一个元素前逆序对的个数减去cdq求出来的删除前一个元素减少的逆序对个数。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
#define en '\n'
#define low(x) ((x)&(-x))
using namespace std;
typedef long long ll;
template<class T>void rd(T &x)
{
x=0;int f=0;char ch=getchar();
while(ch<'0'||ch>'9') {f|=(ch=='-');ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x=f?-x:x;
return;
}
const int N=1e5+5,M=5e4+5,INF=0x3f3f3f3f;
struct node{int t,pos,val,id;}q1[N],q2[N],tmp[N];
int pos[N];
int cmp(node x,node y){
return x.t>y.t;
}
int ans[M],c[N],n;
void add(int x){
while(x<=n) ++c[x],x+=low(x);
}
int getsum(int x){
int res=0;
while(x) res+=c[x],x-=low(x);
return res;
}
void _clear(int x){
while(x<=n)
{
if(!c[x]) break;
c[x]=0,x+=low(x);
}
}
void cdq1(int l,int r)
{
if(l==r) return;
int mid=(l+r)>>1;
cdq1(l,mid),cdq1(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r)
{
if(q1[i].pos<q1[j].pos) add(q1[i].val),tmp[k++]=q1[i++];
else
{
if(q1[j].id) ans[q1[j].id]+=getsum(n)-getsum(q1[j].val-1);
tmp[k++]=q1[j++];
}
}
while(i<=mid) add(q1[i].val),tmp[k++]=q1[i++];
while(j<=r)
{
if(q1[j].id) ans[q1[j].id]+=getsum(n)-getsum(q1[j].val);
tmp[k++]=q1[j++];
}
for(int i=l;i<=mid;++i) _clear(q1[i].val),q1[i]=tmp[i];
for(int j=mid+1;j<=r;++j) q1[j]=tmp[j];
}
void cdq2(int l,int r)
{
if(l==r) return;
int mid=(l+r)>>1;
cdq2(l,mid),cdq2(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r)
{
if(q2[i].pos>q2[j].pos) add(q2[i].val),tmp[k++]=q2[i++];
else
{
if(q2[j].id) ans[q2[j].id]+=getsum(q2[j].val);
tmp[k++]=q2[j++];
}
}
while(i<=mid) add(q2[i].val),tmp[k++]=q2[i++];
while(j<=r)
{
if(q2[j].id) ans[q2[j].id]+=getsum(q2[j].val);
tmp[k++]=q2[j++];
}
for(int i=l;i<=mid+1;++i) _clear(q2[i].val),q2[i]=tmp[i];
for(int j=mid+1;j<=r;++j) q2[j]=tmp[j];
}
int main()
{
int m,tot=0;rd(n),rd(m);
for(int i=1;i<=n;++i)
{
int x;rd(x);
q1[i]=(node){INF,i,x,0},pos[x]=i;
}
for(int i=1;i<=m;++i)
{
int x;rd(x);
q1[pos[x]].t=i,q1[pos[x]].id=i;
}
ll sum=0;
for(int i=1;i<=n;++i) sum+=getsum(n)-getsum(q1[i].val-1),add(q1[i].val);
for(int i=1;i<=n;++i) _clear(q1[i].val);
sort(q1+1,q1+n+1,cmp);
for(int i=1;i<=n;++i) q2[i]=q1[i];
cdq1(1,n);cdq2(1,n);
for(int i=1;i<=m;++i) printf("%lld\n",sum),sum-=ans[i];
}