题目
https://uva.onlinejudge.org/index.php?option=com_onlinejudge&Itemid=8&page=show_problem&problem=3141
题意
一个1到n的排列,每次随机删除一个,问删除前的逆序数
思路
综合考虑,对每个数点,令value为值,pos为位置,time为出现时间(总时间-消失时间),明显是统计value1 > value2, pos1 < pos2, time1 < time2的个数
首先对其中一个轴排序,比如value,这样在归并过程中,左子树的value总是小于右子树的,可以分治。
当左右子树包含哪些数点已经确定后,可以用自下而上的归并排序使得子树上的数点按照第二维相对有序,方便用尺取法统计子树之间的逆序数。
第三维通过树状数组进行压缩,加快统计速度。
注意仅仅统计左子树对右子树的影响,就会错过右子树中的数点出现的比较晚的情况。因此需要统计右子树对左子树的影响,此时注意别把同一时间出现的重复计数。
感想
1. 注意long long!!!
2. BIT的上限要>=n!
3. 注意统计影响完成后需要清空树状数组(区间大小已经减少了所以可以浪费地使用),此时不能直接用memset清空整个数组,时间会成为O(n2),超时。
代码
时间: 0.250s
时间复杂度O(cnlogn)
#include <algorithm> #include <cassert> #include <cmath> #include <cstdio> #include <cstring> #include <queue> #include <tuple> #include <cassert> using namespace std; const int MAXN = int(4e5 + 4); #define LEFT_CHILD(x) ((x) << 1) #define RIGHT_CHILD(x) (((x) << 1) + 1) #define FATHER(x) ((x) >> 1) #define IS_LEFT_CHILD(x) (((x) & 1) == 0) #define IS_RIGHT_CHILD(x) (((x) & 1) == 1) #define BROTHER(x) ((x) ^ 1) #define LOWBIT(x) ((x) & (-x)) #define LOCAL_DEBUG struct Node{ int value, pos, time; }nodes[MAXN], tmpNodes[MAXN]; int timeCnt[MAXN * 4]; long long revNum[MAXN]; int clearStack[MAXN]; int clearLen; int n, m; int bitLimit; int getHigherBit(int n) { int x = 1; while (x < n) { x <<= 1; } return x; } void update(int id) { while (id <= bitLimit) { if (timeCnt[id] == 0) { clearStack[clearLen++] = id; } timeCnt[id]++; id += LOWBIT(id); } } void clearCnt() { while (clearLen > 0) { timeCnt[clearStack[--clearLen]] = 0; } } int countTimesSmaller(int id) { if (id < 0)return 0; int sum = 0; int tmp = 0; while (id > 0) { sum += timeCnt[id]; id -= LOWBIT(id); } return sum; } void merge_by_pos(int root_ind, int internal_l, int internal_r) { int internal_mid = (internal_l + internal_r) >> 1; for (int i = internal_l; i <= internal_r; i++) { tmpNodes[i] = nodes[i]; } for (int i = internal_l, j = internal_mid + 1, ind = internal_l; ind <= internal_r; ) { if (i > internal_mid) { nodes[ind++] = tmpNodes[j++]; } else if (j > internal_r) { nodes[ind++] = tmpNodes[i++]; } else if (tmpNodes[i].pos < tmpNodes[j].pos) { nodes[ind++] = tmpNodes[i++]; } else { nodes[ind++] = tmpNodes[j++]; } } } void cal(int root_ind, int internal_l, int internal_r) { if (internal_l == internal_r)return; int internal_mid = (internal_l + internal_r) >> 1; if(internal_l != internal_mid)cal(LEFT_CHILD(root_ind), internal_l, internal_mid); if (internal_mid + 1 != internal_r)cal(RIGHT_CHILD(root_ind), internal_mid + 1, internal_r); // printf("L Node: %d[%d, %d] LC: %d[%d, %d], RC: %d[%d, %d]\n", root_ind, internal_l, internal_r, LEFT_CHILD(root_ind), internal_l, internal_mid, RIGHT_CHILD(root_ind), internal_mid + 1, internal_r); for (int i = internal_l, j = internal_mid + 1; i <= internal_mid; i++) { while (j <= internal_r && nodes[i].pos > nodes[j].pos) { update(nodes[j].time); j++; } revNum[nodes[i].time] += countTimesSmaller(nodes[i].time); // printf("L (%d, %d, %d): +%d\n", nodes[i].value, nodes[i].pos, nodes[i].time, countTimesSmaller(nodes[i].time)); } clearCnt(); for (int i = internal_mid, j = internal_r; j > internal_mid; j--) { while (i >= internal_l && nodes[i].pos > nodes[j].pos) { update(nodes[i].time); i--; } revNum[nodes[j].time] += countTimesSmaller(nodes[j].time - 1); // printf("R (%d, %d, %d): +%d\n", nodes[j].value, nodes[j].pos, nodes[j].time, countTimesSmaller(nodes[j].time - 1)); } clearCnt(); merge_by_pos(root_ind, internal_l, internal_r); } int main() { #ifdef LOCAL_DEBUG freopen("input.txt", "r", stdin); freopen("output2.txt", "w", stdout); #endif // LOCAL_DEBUG for (int ti = 1; scanf("%d%d", &n, &m) == 2; ti++) { bitLimit = getHigherBit(n); for (int i = 1; i <= n; i++) { int tmp; scanf("%d", &tmp); nodes[tmp].value = tmp; nodes[tmp].pos = i; nodes[tmp].time = 1; } for (int i = 1; i <= m + 1; i++) { revNum[i] = 0; } for (int i = 0; i < m; i++) { int tmp; scanf("%d", &tmp); nodes[tmp].time = m - i + 1; } cal(1, 1, n); long long ans = 0; for (int i = 1; i <= m + 1; i++) { ans += revNum[i]; } for (int i = 0; i < m; i++) { printf("%lld\n", ans); ans -= revNum[m - i + 1]; } } return 0; }