题目大意：给你 n 个数，两两异或共得到 $n(n-1)/2$ 个数，求其中前 k 小的数。
第一次知道 trie 树可以查询异或值的第 k 小。
解题思路有点像这道题：题目传送门（原链接已缺失）。
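我们可以用优先队列（小根堆）维护：先对每个数求出它与所有数异或结果中的第 2 小值（第 1 小一定是它与自己异或得到的 0）放入堆中；每次把堆顶最小值弹出，将该元素的排名 rk 加一，再回到 trie 树里查询它的下一小值（第 3 小、第 4 小……）并放回堆中，以此类推，直到找到第 k 个。注意每一对 (i, j) 会从 i 和 j 两侧各被取出一次，所以需要弹出 2k 次，并只输出第奇数次弹出的值。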
#include <iostream>
#include <cstdio>
#include <stack>
#include <sstream>
#include <limits.h>
#include <vector>
#include <map>
#include <cstring>
#include <deque>
#include <cmath>
#include <iomanip>
#include <queue>
#include <algorithm>
#include <set>
// ---- Generic competitive-programming template macros (most are unused in this file) ----
// Segment-tree helpers. NOTE(review): Lson/Rson expand to `mid`, not `Mid`, so they
// only work where a local variable named `mid` is in scope.
#define Mid ((l + r) >> 1)
#define Lson rt << 1, l , mid
#define Rson rt << 1|1, mid + 1, r
#define ms(a,al) memset(a,al,sizeof(a))
// NOTE(review): the expansion is not parenthesized (e.g. `1/log2(x)` becomes
// `1/log(x)/log(2)`), and the macro hides std::log2 for the rest of the file.
#define log2(a) log(a)/log(2)
// Loop shorthands: _for = [a, b), _rep = [a, b], for_ = [a down to b], rep_ = (a down to b).
#define _for(i,a,b) for( int i = (a); i < (b); ++i)
#define _rep(i,a,b) for( int i = (a); i <= (b); ++i)
#define for_(i,a,b) for( int i = (a); i >= (b); -- i)
#define rep_(i,a,b) for( int i = (a); i > (b); -- i)
// Lowest set bit of x. NOTE(review): x is not parenthesized, so lowbit(a + b) expands
// to ((-a + b) & a + b) — only pass a plain variable.
#define lowbit(x) ((-x) & x)
// Fast-iostream switch (not used here; input goes through getchar()).
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define INF 0x3f3f3f3f
#define LLF 0x3f3f3f3f3f3f3f3f
// Presumably renamed to avoid clashing with std::hash / std::next pulled in by
// `using namespace std` below.
#define hash Hash
#define next Next
#define pb push_back
// NOTE(review): these rewrite EVERY later identifier token `f` / `s` in the file.
// For example, the local `f` declared inside read() is actually compiled as `first`.
#define f first
#define s second
using namespace std;
// N bounds the number of input values (+ slack); it sizes a[] and the trie arrays.
// MOD, eps and p are not used in the visible code.
const int N = 1e5 + 10, MOD = 998244353;
const long double eps = 1e-5;
const int p = 2333;
// Short type aliases.
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef pair<double,double> PDD;
// Reads the next integer from stdin into x (getchar()-based fast input).
// Non-digit characters before the number are skipped; a '-' immediately before the
// first digit makes the value negative. If EOF is reached before any digit, x is
// left as 0 and the function returns (instead of spinning forever on EOF).
template<typename T> void read(T &x)
{
    x = 0;
    // Keep the getchar() result in an int: it is either EOF or an unsigned-char
    // value. Those are the only values isdigit() is defined for, and EOF must stay
    // distinguishable from a real byte.
    int ch = getchar();
    bool negative = false;
    while (ch != EOF && !isdigit(ch))
    {
        negative = (ch == '-');
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch))
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
// Reads several integers from stdin, filling the arguments in the order given,
// e.g. read(n, m). Each argument is handled by the single-value read() overload.
template<typename T, typename... Args> void read(T &head, Args& ... rest)
{
    read(head);
    // Elements of a braced-init-list are evaluated strictly left to right, so this
    // expands to read(rest_0), read(rest_1), ... in argument order.
    using expand_t = int[];
    (void)expand_t{0, ((void)read(rest), 0)...};
}
// Binary trie over keys' bits 30..0. Node 0 is the root; tr[v][b] is the child of v
// along bit value b (0 means "no child"), and siz[v] counts how many inserted values
// pass through node v. insert() creates at most 31 nodes per value, so N * 40 slots
// are enough for N values.
int tr[N * 40][2], siz[N * 40];
// Index of the most recently allocated trie node (0 = only the root exists).
int idx;
// Input values, stored 1-indexed as a[1..n].
int a[N];
// A heap candidate for one base value:
//   x   - the id-th smallest value of (ans XOR a[j]) over all inserted a[j]
//   ans - the base value a[i] this candidate belongs to
//   id  - the rank (1-based) that x currently holds for ans
struct node {
    int x, ans, id;
    // Deliberately reversed: std::priority_queue is a max-heap on operator<, so
    // "less" here means "larger x", which turns the queue into a min-heap on x.
    bool operator < (const node &other) const {
        return other.x < x;
    }
};
// Pending candidates, one per base value. node::operator< is reversed, so top()
// returns the candidate with the smallest x.
priority_queue<node> heap;
void insert(int x)
{
int rt = 0;
for(int i = 30; i >= 0; -- i)
{
int ch = (x >> i & 1);
if(!tr[rt][ch]) tr[rt][ch] = ++ idx;
rt = tr[rt][ch];
siz[rt] ++;
}
}
int ask(int x, int k)
{
int rt = 0, res = 0;
for(int i = 30; i >= 0; -- i)
{
int ch = x >> i & 1;
if(siz[tr[rt][ch]] >= k) rt = tr[rt][ch];
else k -= siz[tr[rt][ch]], res += 1 << i, rt = tr[rt][ch ^ 1];
}
return res;
}
int n, m;
int main()
{
read(n,m);
for(int i = 1; i <= n; ++ i)
{
read(a[i]);
insert(a[i]);
}
for(int i= 1; i <= n; ++ i)
{
node x;
x.x = ask(a[i],2);
x.id = 2;
x.ans = a[i];
heap.push(x);
}
for(int i = 1; i <= m * 2; ++ i)
{
node x = heap.top();
heap.pop();
if(i & 1) cout << x.x << " ";
if(x.id == n) continue;//rk到n了就不行了最多是n对于一个数
x.id ++, x.x = ask(x.ans,x.id);
heap.push(x);
}
return 0;
}