链接
http://acm.hdu.edu.cn/showproblem.php?pid=6059
题意
给出 A[1..n] , (1<=∑n<=5∗105) , (0<=A[i]<230) ,要求统计三元组 (i,j,k) 的个数使其满足 i<j<k 并且 (A[i]xorA[j])<(A[j]xorA[k]) 。
思路
事先把所有数字插入字典树中,用字典树维护
A[k]
的信息,接着对每一个
A[i]
,枚举其二进制最高位小于
A[k]
的位数,考虑这样一个情况:若当前枚举到了
A[i]
的二进制第5位比
A[k]
小,那么
A[i]
与
A[k]
的第30位到第6位都是相同的,此时就不用考虑
A[j]
的第30位到第6位如何,只考虑第5位的情况就好。
考虑当前位置
A[i]
的情况,若当前位置
A[i]
为0,那么
A[j]
的相同位置要为0才能使两者异或值为0,此时
A[k]
为1,这时满足条件的
A[j]
与
A[k]
对数可以计入答案。当前位置
A[i]
为1的情况同理(
A[j]
为1,
A[k]
为0)。
对于
A[j]
与
A[k]
对数的统计,在插入
A[k]
时,之前插入的数就都成了
A[j]
。因此用一个cnt[i][j]
数组记录下第i位为j的数之前出现了几次,那么在插入时,对于这一位置
A[k]
为0的情况,之前有多少的
A[j]
在这一位为1,就是此时满足条件的
A[j]
的个数。代码里的cnt[i][nxt ^ 1]
就是此时符合条件的
A[j]
个数。
当我们把一个数从字典树中去掉时,也要考虑去掉这个数留下来的统计值。
这题特殊的地方在于,插入是连续的,之后是连续的删除,所以在插入完成后可以把cnt[i][j]
数组清空一遍,用来记录第i位为j的数被删除了几次。
考虑两个方面:
- 一个是这个数作为
A[k]
直接被去掉带来的影响,像之前一样减去其前面已经被删去的
A[j]
的个数(依然是cnt[now][nxt ^ 1]
)就好。(这一步操作在Trie::Insert()
里面,与插入时的操作类似)
- 还有一个是这个数作为
A[j]
带来的影响,因为这个数已经不能和后面的
A[k]
组合产生贡献了,考虑到在统计时,当前位的
A[k]
已经把可以与其组合的
A[j]
个数统计在了sum[tmp]
中,这里面还需去掉被删去的
A[j]
,被删去的
A[j]
已经被统计在了cnt[i][nxt]
中,现有的
A[k]
被存在了val[tmp]
中,这一部分不能被计入答案,相乘,减去。(sum[tmp] - val[tmp] * cnt[i][nxt]
这一步在函数solve()
里)
希望思路说清楚了,详见代码
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define MS(x, y) memset(x, y, sizeof(x))
typedef long long LL;
const int MAXN = 5e5 + 5;
int bits[32];
struct Trie {
int tot, root;
int val[MAXN * 30], ch[MAXN * 30][2];
LL sum[MAXN * 30], cnt[MAXN][2];
int newnode() {
val[tot] = sum[tot] = 0;
ch[tot][0] = ch[tot][1] = -1;
return tot++;
}
void init() {
tot = 0;
root = newnode();
MS(cnt, 0);
}
void Insert(int x, int v) {
int now = root, nxt, tmp;
for (int i = 30; i >= 0; --i) {
nxt = !!(x & bits[i]);
if (ch[now][nxt] == -1) ch[now][nxt] = newnode();
now = ch[now][nxt];
++cnt[i][nxt];
sum[now] += v * cnt[i][nxt ^ 1];
val[now] += v;
}
}
LL solve(int x) {
LL ret = 0;
int now = root, tmp, nxt;
for (int i = 30; i >= 0; --i) {
nxt = !!(x & bits[i]);
tmp = ch[now][nxt ^ 1];
now = ch[now][nxt];
if (tmp != -1) {
ret += sum[tmp] - val[tmp] * cnt[i][nxt];
}
if (now == -1) break;
}
return ret;
}
};
int n;
int a[MAXN];
LL ans;
Trie trie;
int main() {
bits[0] = 1;
for (int i = 1; i < 32; ++i) bits[i] = bits[i - 1] << 1;
int T;
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) scanf("%d", a + i);
ans = 0;
trie.init();
for (int i = 1; i <= n; ++i) trie.Insert(a[i], 1);
MS(trie.cnt, 0);
for (int i = 1; i < n; ++i) {
trie.Insert(a[i], -1);
ans += trie.solve(a[i]);
}
printf("%I64d\n", ans);
}
}