codeforces 785E (树状数组套平衡树)

题目链接:点击这里

题意:动态逆序对问题。一个[1,2,3n]的数组,每次操作是交换两个元素,输出交换后的逆序对数。

需要计算的是交换的两个数,在它们中间的数中分别有多少数比他们大(小)。利用树状数组的思想,把下标为i的树用第i,i+lowbit(i),i+lowbit(i)+lowbit(i+lowbit(i))...棵平衡树维护。然后只需要用前缀减减的思想就好了。逆序对的改变xjb搞搞即可。

#include <cstdio>
#include <iostream>
#include <cstring>
#include <queue>
#include <cmath>
#include <algorithm>
#include <stack>
#include <map>
#include <string>
#include <set>
#include <stdlib.h>
#define Clear(x,y) memset (x,y,sizeof(x))
#define Close() ios::sync_with_stdio(0)
#define Open() freopen ("more.in", "r", stdin)
#define get_min(a,b) a = min (a, b)
#define get_max(a,b) a = max (a, b);
#define y0 yzz
#define y1 yzzz
#define fi first
#define se second
#define pii pair<int, int>
#define pli pair<long long, int>
#define pll pair<long long, long long>
#define pdi pair<double, int>
#define pdd pair<double, double>
#define pb push_back
#define pl c<<1
#define pr (c<<1)|1
#define lson l,mid,pl
#define rson mid+1,r,pr
typedef unsigned long long ull;
template <class T> inline T lowbit (T x) {return x&(-x);}
template <class T> inline T sqr (T x) {return x*x;}
template <class T>
inline bool scan (T &ret) {
    char c;
    int sgn;
    if (c = getchar(), c == EOF) return 0; //EOF
    while (c != '-' && (c < '0' || c > '9') ) c = getchar();
    sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return 1;
}
const double pi = 3.14159265358979323846264338327950288L;
using namespace std;    
#define mod 1000000007
#define INF 1e9
#define maxn 200005
#define maxm maxn*20
//-----------------morejarphone--------------------//


struct node *null; //安全节点
struct node {
    node *ch[2];
    int v, r, s;
    int cmp(int x) const {
        if(x == v) return -1;
        return x < v? 0: 1;
    }
    void maintain() {
        s = ch[0]->s + ch[1]->s + 1;
    }
}nodes[maxn*30], *root[maxn];
int cnt;

void rot(node *&o, int d) { //旋转,d为0表示左旋,为1表示右旋
    node *p = o->ch[d^1]; o->ch[d^1] = p->ch[d]; p->ch[d] = o;
    o->maintain(); p->maintain(); o = p;
}

void ins(node *&o, int x) { //插入值为x的节点
    if(o == null) {
        o = &nodes[cnt++];
        o->ch[0] = o->ch[1] = null;
        o->v = x; o->r = rand(); o->s = 1;
    } else {
        int d = o->cmp(x);
        ins(o->ch[d], x); if(o->ch[d]->r > o->r) rot(o, d^1);
    }
    o->maintain();
}

void del(node *&o, int x) { //删掉值为x的节点,调用前保证节点存在
    int d = o->cmp(x);
    if(d == -1) {
        if(o->ch[0] == null) o = o->ch[1];
        else if(o->ch[1] == null) o = o->ch[0];
        else {
            int d2 = o->ch[0]->r > o->ch[1]->r? 1: 0;
            rot(o, d2); del(o->ch[d2], x);
        }
    } else {
        del(o->ch[d], x);
    }
    if(o != null) o->maintain();
}

int mcount(node *p, int v) { //返回以p为根的名次树中大于v的个数
    if(p == null) return 0;
    if (p->v > v) return 1+p->ch[1]->s+mcount (p->ch[0], v);
    else return mcount (p->ch[1], v);
}

int n, m, a[maxn];

void init () {
    srand ((long long)time (0));
    null = &nodes[0]; null->s = 0;
    cnt = 1;
    for (int i = 1; i <= n; i++) { 
        root[i] = null;
        a[i] = i;
    } 
    for (int i = 1; i <= n; i++) {
        for (int j = i; j <= n; j += lowbit (j)) ins (root[j], a[i]);
    }
}

long long sum = 0;

int query (int x, int val) {//查询x子树中大于val的个数
    int ans = 0;
    for (int i = x; i; i -= lowbit (i)) {
        ans += mcount (root[i], val);
    }   return ans;
}

int main () {
    cin >> n >> m; 
    init (); 
    for (int i = 0; i < m; i++) {
        int l, r; scan (l); scan (r);
        if (r < l) swap (l, r);
        if (l == r) {
            printf ("%lld\n", sum);
            continue;
        }
        int tmp1 = query (l, a[l]);
        int tmp2 = query (r, a[r]);
        if (a[r] < a[l]) tmp2--;
        for (int i = l; i <= n; i += lowbit (i)) del (root[i], a[l]);
        for (int i = r; i <= n; i += lowbit (i)) del (root[i], a[r]); 
        for (int i = l; i <= n; i += lowbit (i)) ins (root[i], a[r]);
        for (int i = r; i <= n; i += lowbit (i)) ins (root[i], a[l]);
        int tmp3 = query (l, a[r]);
        int tmp4 = query (r, a[l]);
        if (a[r] > a[l]) tmp4--;
        sum += ((tmp4-tmp1)-(r-l-1-(tmp4-tmp1))+r-l-1-2*(tmp2-tmp3));
        if (a[r] < a[l]) sum--;
        else sum++;
        swap (a[l], a[r]);
        printf ("%lld\n", sum);
    }
    return 0;
}
阅读更多
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/morejarphone/article/details/64126515
个人分类: 平衡树 树状数组
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭
关闭