BZOJ 2212
从下到上线段树合并。
考虑到每一个子树内部产生的贡献不可能通过换儿子消除,所以一次更换只要看看把哪个儿子放在左边产生的逆序对数少就可以了。
逆序对数可以在线段树合并的时候顺便算出来。
由于只有叶子结点有权值 + 二叉树的特性,大大方便了这道题的代码和细节处理。
注意点数总共要开到$2 * n$。
时间复杂度$O(nlogn)$。
Code:
#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int N = 4e5 + 5; int m, n = 0, rt = 0; ll ans = 0LL; struct Node { int lc, rc, w; } a[N]; inline void read(int &X) { X = 0; char ch = 0; int op = 1; for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } inline ll min(ll x, ll y) { return x > y ? y : x; } void build(int &now) { now = ++n; read(a[now].w); if(a[now].w) return; build(a[now].lc), build(a[now].rc); } namespace SegT { struct Node { int lc, rc; ll siz; } s[N * 30]; int sta[N * 30], top = 0, root[N], nodeCnt = 0; ll res1, res2; inline void push(int now) { sta[++top] = now; } inline int newNode() { if(top) return sta[top--]; else return ++nodeCnt; } #define lc(p) s[p].lc #define rc(p) s[p].rc #define siz(p) s[p].siz #define mid ((l + r) >> 1) inline void up(int p) { if(!p) return; siz(p) = siz(lc(p)) + siz(rc(p)); } void ins(int &p, int l, int r, int x) { if(!p) p = newNode(); ++siz(p); if(l == r) return; if(x <= mid) ins(lc(p), l, mid, x); else ins(rc(p), mid + 1, r, x); } int merge(int u, int v, int l, int r) { if(!u || !v) return u + v; res1 += siz(rc(u)) * siz(lc(v)); res2 += siz(rc(v)) * siz(lc(u)); int p = newNode(); if(l == r) siz(p) = siz(u) + siz(v); else { lc(p) = merge(lc(u), lc(v), l, mid); rc(p) = merge(rc(u), rc(v), mid + 1, r); up(p); } push(u), push(v); return p; } void print(int p, int l, int r) { if(l == r) { printf("%lld", siz(p)); return; } print(lc(p), l, mid), print(rc(p), mid + 1, r); } inline void deb(int x) { print(root[x], 1, m); } #undef lc #undef rc #undef mid #undef siz } using namespace SegT; void solve(int now) { if(a[now].w) return; solve(a[now].lc), solve(a[now].rc); res1 = res2 = 0LL; root[now] = merge(root[a[now].lc], root[a[now].rc], 1, m); ans += min(res1, res2); } int main() { read(m); build(rt); /* for(int i = 1; i <= n; i++) printf("%d %d %d\n", a[i].lc, a[i].rc, a[i].w); printf("\n"); */ for(int i = 1; i <= n; i++) if(a[i].w) ins(root[i], 1, m, a[i].w); /* for(int i = 1; i <= n; i++) { if(!a[i].w) continue; deb(i); printf("\n"); } */ solve(rt); printf("%lld\n", ans); return 0; }