You are given a permutation of {1, 2, 3, ..., n}. Remove m of its elements one by one, and output the number of inversion pairs before each removal. The number of inversion pairs of an array A is the number of index pairs (i, j) such that i < j and A[i] > A[j].
Input
The input contains several test cases. The first line of each case contains two integers n and m (1<=n<=200,000, 1<=m<=100,000). After that, n lines follow, representing the initial permutation. Then m lines follow, representing the removed integers, in the order of the removals. No integer will be removed twice. The input is terminated by end-of-file (EOF). The size of input file does not exceed 5MB.
Output
For each removal, output the number of inversion pairs before it.
Sample Input
5 4
1
5
3
4
2
5
1
4
2
Output for the Sample Input
5
2
2
1
题意:给出一个1~n的排列A,要求按照某种顺序删除一些数(其他数顺序不变),输出每次删除之前逆序对的数目。
思路:O(nlogn)求出初始逆序数。然后树状数组套静态BST。
树状数组的每个结点对应一棵静态平衡BST,存放该结点所覆盖区间内的所有数。删除一个元素时,只要计算出该元素所贡献的逆序对数,并从当前逆序数中减去,即可得到删除后的逆序数。
该贡献等于它前面比它大的数的个数加上它后面比它小的数的个数。只需在树状数组对应的若干棵BST上查询,即可统计出来。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

// Dynamic inversion counting.
//
// A Fenwick tree indexed by position, where every Fenwick node owns a static,
// perfectly balanced BST over the values in the position range it covers.
// Deleting a value marks it dead in O(log n) BSTs. The inversions it takes
// with it are (larger values to its left) + (smaller values to its right).
// Each of those counts is answered in O(log^2 n) by walking the BSTs of the
// relevant Fenwick prefixes.

namespace {

constexpr int kMaxN = 200080;

// Fenwick node i covers positions (i - lowbit(i), i]. Nodes sharing the same
// lowbit cover disjoint ranges, so each position is stored at most once per
// lowbit level. With n <= 200000 < 2^18 there are at most 18 levels, hence at
// most 18 * n BST nodes overall. Index 0 is reserved as the null sentinel.
constexpr int kMaxLevels = 18;
constexpr int kMaxNodes = kMaxN * kMaxLevels + 1;

// BST node pool. Node 0 is the empty child and must keep Size 0.
int lson[kMaxNodes];  // left child index (0 = none)
int rson[kMaxNodes];  // right child index (0 = none)
int key[kMaxNodes];   // value stored at the node
int vis[kMaxNodes];   // 1 while the value is still present, 0 once deleted
int Size[kMaxNodes];  // number of still-present values in the subtree
int cnt;              // index of the last allocated node

int a[kMaxN];     // a[i]: value at position i (1-based)
int b[kMaxN];     // scratch buffer used to sort one Fenwick range
int ope[kMaxN];   // values to delete, in deletion order
int root[kMaxN];  // root[i]: BST root of Fenwick node i
int Pos[kMaxN];   // Pos[v]: position of value v in a
int seen[kMaxN];  // value-indexed Fenwick tree for the initial inversion count

int lowbit(int x) { return x & -x; }

// Builds a perfectly balanced BST rooted at `node` over the sorted slice
// b[l..r]. Child nodes are allocated from the global counter `cnt`.
void buildBst(int node, int l, int r) {
    if (l > r) return;
    const int mid = (l + r) >> 1;
    key[node] = b[mid];
    vis[node] = 1;
    lson[node] = 0;
    rson[node] = 0;
    if (l < mid) {
        lson[node] = ++cnt;
        buildBst(lson[node], l, mid - 1);
    }
    if (mid < r) {
        rson[node] = ++cnt;
        buildBst(rson[node], mid + 1, r);
    }
    Size[node] = 1 + Size[lson[node]] + Size[rson[node]];
}

// Resets the node pool and builds one BST per Fenwick node 1..n from a[1..n].
void buildForest(int n) {
    cnt = 0;
    Size[0] = lson[0] = rson[0] = vis[0] = 0;
    for (int i = 1; i <= n; ++i) {
        const int from = i - lowbit(i) + 1;
        const int len = i - from + 1;
        std::copy(a + from, a + i + 1, b + 1);
        std::sort(b + 1, b + 1 + len);
        root[i] = ++cnt;
        buildBst(root[i], 1, len);
    }
}

// Marks value k as deleted in the BST rooted at `node`, decrementing the
// subtree sizes along the search path.
void eraseValue(int node, int k) {
    while (node) {
        --Size[node];
        if (key[node] == k) {
            vis[node] = 0;
            return;
        }
        node = key[node] > k ? lson[node] : rson[node];
    }
}

// Deletes a[pos] from every Fenwick node whose range contains `pos`.
void eraseAt(int pos, int n) {
    for (int i = pos; i <= n; i += lowbit(i)) eraseValue(root[i], a[pos]);
}

// Number of still-present values greater than k in the BST rooted at `node`.
int countGreater(int node, int k) {
    int result = 0;
    while (node) {
        if (key[node] == k) return result + Size[rson[node]];
        if (key[node] > k) {
            result += vis[node] + Size[rson[node]];
            node = lson[node];
        } else {
            node = rson[node];
        }
    }
    return result;
}

// Number of still-present values less than k in the BST rooted at `node`.
int countLess(int node, int k) {
    int result = 0;
    while (node) {
        if (key[node] == k) return result + Size[lson[node]];
        if (key[node] < k) {
            result += vis[node] + Size[lson[node]];
            node = rson[node];
        } else {
            node = lson[node];
        }
    }
    return result;
}

// Inversions involving a[pos] among the still-present values. This must be
// called after a[pos] itself has been erased, so that it is not counted in
// the prefix sizes. Returns (greater values left of pos) + (smaller values
// right of pos).
int contribution(int pos, int n) {
    const int v = a[pos];
    int leftAlive = 0;
    int leftGreater = 0;
    for (int i = pos; i > 0; i -= lowbit(i)) {
        leftAlive += Size[root[i]];
        leftGreater += countGreater(root[i], v);
    }
    int totalLess = 0;
    for (int i = n; i > 0; i -= lowbit(i)) totalLess += countLess(root[i], v);
    const int leftLess = leftAlive - leftGreater;
    return leftGreater + (totalLess - leftLess);
}

// Inversion count of a[1..n], assumed to be a permutation of 1..n. For each
// element, (earlier elements) - (earlier smaller elements) gives the earlier
// larger ones.
long long countInversions(int n) {
    std::fill(seen + 1, seen + n + 1, 0);
    long long inversions = 0;
    for (int i = 1; i <= n; ++i) {
        int smallerBefore = 0;
        for (int j = a[i]; j > 0; j -= lowbit(j)) smallerBefore += seen[j];
        inversions += (i - 1) - smallerBefore;
        for (int j = a[i]; j <= n; j += lowbit(j)) ++seen[j];
    }
    return inversions;
}

}  // namespace

int main() {
    int n, m;
    while (std::scanf("%d%d", &n, &m) == 2) {
        for (int i = 1; i <= n; ++i) {
            std::scanf("%d", &a[i]);
            Pos[a[i]] = i;
        }
        for (int i = 1; i <= m; ++i) std::scanf("%d", &ope[i]);

        buildForest(n);
        long long ans = countInversions(n);

        for (int i = 1; i <= m; ++i) {
            const int pos = Pos[ope[i]];  // position of the value being removed
            std::printf("%lld\n", ans);   // answer is requested *before* removal
            eraseAt(pos, n);
            ans -= contribution(pos, n);
        }
    }
    return 0;
}