问题
给定整数m以及n各数字A1,A2,..An,将数列A中所有元素两两异或,共能得到n(n-1)/2个结果,请求出这些结果中大于m的有多少个。
输入描述:
第一行包含两个整数n,m.
第二行给出n个整数A1,A2,…,An。
数据范围
对于30%的数据,1 <= n, m <= 1000
对于100%的数据,1 <= n, m, Ai <= 10^5
输出描述:
输出仅包括一行,即所求的答案
思路
使用字典树(TrieTree)从高位到低位建立字典,
再使用每个元素依次去字典中查对应高位异或结果。
若m对应位置为1, 则当前元素在该位的异或也必须为1;
若m对应位置为0,则加上与当前元素异或结果为1的元素个数;
将所有元素查找后的结果相加,然后再除以2,就是最终的结果。
代码
#include<iostream>
using namespace std;
const int N = 1e6 + 10;
int a[N];
struct TrieNode{
int count;
TrieNode* next[2];
TrieNode(){
count = 0;
next[0] = NULL;
next[1] = NULL;
}
};
void Insert(TrieNode* root, int value){
TrieNode* p = root;
for(int i=31; i>=0; i--) {
int temp = (value >> i) & 1;
if (p->next[temp] == NULL){
p->next[temp] = new TrieNode();
}
p = p->next[temp];
p->count++;
}
}
long long Find(TrieNode* root, int value, int m) {
TrieNode *p = root;
long long result = 0;
for(int i=31; i>=0; i--) {
int val = (value>>i) & 1;
int mval = (m>>i) & 1;
if(mval == 1) {
p = p->next[val^1];
} else {
if (p->next[val^1] != NULL) {
result += p->next[val^1]->count;
}
p = p->next[val];
}
if (p == NULL)
break;
}
return result;
}
int main(){
int m,n;
long long result = 0;
cin>>n>>m;
TrieNode* root = new TrieNode();
for(int i=0; i<n; i++) {
cin>>a[i];
Insert(root, a[i]);
}
for(int i=0; i<n; i++) {
result += Find(root, a[i], m);
}
cout<<(result>>1)<<endl;
return 0;
}