HDU 5193
题意:给出n个数的序列a,m个操作。
操作1:[x,y] 将y插入到第x个人之后.
操作2:[x],将第x个人删除(x+1,..n向前进一格).
n,m,a[i]<=2e4. 问每次操作后序列a的逆序对(i,j)有多少? (i<j && a[i]>a[j] ).
假如当前逆序对为res,那么插入一个数y之后 要知道[x+1,n]有多个比y小,[1..x-1]直接有多个比y大.
插入,删除操作 如何处理下标?
此时用到一个叫 块状链表的东西,链表中每个元素是一个数组,
数组大小最多为2sqrt(n) 若超过2sqrt(n),则用到分裂操作.
若相邻两个表 元素个数<=sqrt(n) 则合并这两个表. 块状链表的插入和删除操作都是sqrt(n)滴。
现在对链表中的每一块,套一个树状数组.
x之后有多少个比x大 则链表往后走 每走一个向其BIT查询大于x的个数.
x之前有多少个比x小 则链表往前走.查询每块中小于x的个数.
然后对于块内的元素,sqrt(n)暴力查询即可.
#include <bits/stdc++.h>
using namespace std;
typedef pair<int,int> ii;
const int N=2e4+5,m=320;
int lowbit(int x){return x&-x;}
void add(int c[],int x,int val)
{
for(int i=x;i<N;i+=lowbit(i))
c[i]+=val;
}
int sum(int c[],int l,int r)
{
int sum1=0,sum2=0;
while(l>0)
{
sum1+=c[l];
l-=lowbit(l);
}
while(r>0)
{
sum2+=c[r];
r-=lowbit(r);
}
return sum2-sum1;
}
struct data{
int s,a[N*2];
data *next;
int c[N];
data()
{
memset(c,0,sizeof(c));
next=NULL;
}
};
data *root;
void insert(int x,int pos)
{
if(root==NULL)
{
root=new data;
root->s=1;
root->a[1]=x;
add(root->c,x,1);//
return;
}
data *k=root;
while(pos> k->s && k->next!=NULL)
{
pos-=k->s;
k=k->next;
}
memmove(k->a+pos+1,k->a+pos,sizeof(int)*(k->s-pos+1));
k->s++;
k->a[pos]=x;
add(k->c,x,1);
//split
if(k->s==2*m)
{
data *t=new data;
t->next=k->next;
k->next=t;
memcpy(t->a+1,k->a+m+1,sizeof(int)*m);
for(int i=1;i<=m;i++)
{
add(k->c,t->a[i],-1);
add(t->c,t->a[i],1);
}
t->s=k->s=m;
}
}
int find(int pos)
{
data *k=root;
while(pos>k->s && k->next!=NULL)
{
pos-=k->s;
k=k->next;
}
return k->a[pos];
}
int work(int pos)
{
int res=0;
data *k=root;
int x=find(pos);
while(pos>k->s && k->next!=NULL)
{
pos-=k->s;
res+=sum(k->c,x,N);//large than x
k=k->next;
}
for(int i=1;i<pos;i++)
if(k->a[i]>x)
res++;
for(int i=pos+1;i<=k->s;i++)
if(k->a[i]<x)
res++;
while(k->next!=NULL)
{
k=k->next;
res+=sum(k->c,0,x-1);
}
return res;
}
void destroy(data *k)
{
if(k->next!=NULL)
destroy(k->next);
delete k;
}
void del(int pos)
{
data *k=root;
while(pos>k->s&&k->next!=NULL)
{
pos-=k->s;
k=k->next;
}
add(k->c,k->a[pos],-1);
memmove(k->a+pos,k->a+pos+1,sizeof(int)*(k->s -pos));
k->s--;
}
int main()
{
int n,p;
while(~scanf("%d%d",&n,&p))
{
root=NULL;
int ans=0,x;
for(int i=1;i<=n;i++)
{
scanf("%d",&x);
insert(x,i);
ans+=work(i);
}
while(p--)
{
int q,x,y;
scanf("%d",&q);
if(q==0)
{
scanf("%d %d",&x,&y);
x++;
insert(y,x);
ans+=work(x);
}
else
{
scanf("%d",&x);
ans-=work(x);
del(x);
}
printf("%d\n",ans);
}
destroy(root);
}
return 0;
}