题目描述
对于序列 A A A,它的逆序对数定义为满足 i < j i<j i<j,且 A i > A j A_i>A_j Ai>Aj的数对 ( i , j ) (i,j) (i,j)的个数。
给 1 1 1到 n n n的一个排列,按照某种顺序依次删除 m m m个元素。
你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
输入格式
输入第一行包括 n n n和 m m m,表示初始元素的个数和删除的元素个数。
第二行 n n n个元素表示初始序列。
第三行 m m m个数表示每次要删除元素的值。
输出格式
输出包含 m m m行,表示删除每个元素之前,该序列逆序对的个数。
输入样例
5 4
1 5 3 4 2
5 1 4 2
输出样例
5
2
2
1
数据范围
n ≤ 1 0 5 , m ≤ 5 × 1 0 4 n\leq 10^5,m\leq 5\times10^4 n≤105,m≤5×104
题解
前置知识: c d q cdq cdq分治
删除操作比较麻烦,所以我们可以倒着做,一开始有 n − m n-m n−m个元素,依次加入 m m m个元素,求每加入一个元素之后,该序列逆序对的数量。
设 a i a_i ai表示第 i i i个数的值, t i t_i ti表示第 i i i个数插入序列的时间。先将开始的 n − m n-m n−m个元素插入序列,再将 m m m个元素依次插入序列。
当 i i i插入序列后,序列的逆序对增加的数量为满足以下条件的 j j j的数量:
j < i , a j > a i , t j < t i j<i,a_j>a_i,t_j<t_i j<i,aj>ai,tj<ti或者 j > i , a j < a i , t j < t i j>i,a_j<a_i,t_j<t_i j>i,aj<ai,tj<ti
所以,我们可以先按 j < i , a j > a i , t j < t i j<i,a_j>a_i,t_j<t_i j<i,aj>ai,tj<ti为偏序做一次 c d q cdq cdq,再按 j > i , a j < a i , t j < t i j>i,a_j<a_i,t_j<t_i j>i,aj<ai,tj<ti为偏序做一次 c d q cdq cdq,就可以得出各点加入序列后对逆序对数量的贡献。这里用到了三维偏序 c d q cdq cdq,所以要用 c d q cdq cdq中还要用数状数组来维护答案。
最后按顺序将贡献累加,即可达到答案。时间复杂度为 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)
code
#include<bits/stdc++.h>
using namespace std;
int n,m,vt=0,d[100005],tr[100005];
struct node{
int x,id,t;
long long ans;
}a[100005],b[100005];
bool cmp(node ax,node bx){
return ax.t<bx.t;
}
int lb(int i){
return i&(-i);
}
void add(int i,int t){
while(i<=n){
tr[i]+=t;i+=lb(i);
}
}
int find(int i){
int re=0;
while(i){
re+=tr[i];i-=lb(i);
}
return re;
}
void cdq1(int l,int r){
if(l==r) return;
int mid=(l+r)/2;
cdq1(l,mid);cdq1(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r){
if(a[i].id<a[j].id){
add(a[i].x,1);
b[k]=a[i];++i;++k;
}
else{
a[j].ans+=find(n)-find(a[j].x-1);
b[k]=a[j];++j;++k;
}
}
while(i<=mid){
add(a[i].x,1);b[k]=a[i];++i;++k;
}
while(j<=r){
a[j].ans+=find(n)-find(a[j].x-1);
b[k]=a[j];++j;++k;
}
for(int o=l;o<=mid;o++){
add(a[o].x,-1);
}
for(int o=l;o<=r;o++) a[o]=b[o];
}
void cdq2(int l,int r){
if(l==r) return;
int mid=(l+r)/2;
cdq2(l,mid);cdq2(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r){
if(a[i].x<a[j].x){
add(a[i].id,1);
b[k]=a[i];++i;++k;
}
else{
a[j].ans+=find(n)-find(a[j].id-1);
b[k]=a[j];++j;++k;
}
}
while(i<=mid){
add(a[i].id,1);b[k]=a[i];++i;++k;
}
while(j<=r){
a[j].ans+=find(n)-find(a[j].id-1);
b[k]=a[j];++j;++k;
}
for(int o=l;o<=mid;o++){
add(a[o].id,-1);
}
for(int o=l;o<=r;o++) a[o]=b[o];
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&a[i].x);a[i].id=i;
}
for(int i=1,x;i<=m;i++){
scanf("%d",&x);d[x]=m-i+1;
}
for(int i=1;i<=n;i++){
if(!d[a[i].x]) d[a[i].x]=--vt;
a[i].t=d[a[i].x];
}
sort(a+1,a+n+1,cmp);
cdq1(1,n);
sort(a+1,a+n+1,cmp);
cdq2(1,n);
sort(a+1,a+n+1,cmp);
for(int i=2;i<=n;i++) a[i].ans+=a[i-1].ans;
for(int i=n;i>=1;i--){
if(a[i].t>0) printf("%lld\n",a[i].ans);
}
return 0;
}