解题思路:这题用字典树(01-Trie)比较好处理。先把 P 中的所有数按二进制从高位到低位插入字典树;对每个 Ai,从最高位开始贪心匹配,每一位都尽量走与 Ai 当前位相同的分支,也就是说,尽量找到与 Ai 最像的数。如果在第 j 位没有相同的分支,就只能走另一个分支,并在 ans 上加上 2^j。删除也很好处理:匹配时只要把沿途经过的节点的 val 值减一就行了。
另外一种做法是用 set(multiset)模拟,思路有点难理解,也不太好用文字讲清楚,这里就不展开说了,代码同样附在下面。
字典树:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// Upper bound used to size the Ai array and the trie node pool.
const int mx = 3e5 + 10;
// n: element count read by main; m: scratch variable for each value read
// in the second input pass; Ai: the first input sequence; root: index of
// the trie root; tot: number of trie nodes handed out so far by newnode().
int n,m,Ai[mx],root,tot = 0;
// Bit pattern written by Getbit() and read by insert()/find(); after
// Getbit, bits[0] is the most significant of the 31 bits.
bool bits[33];
// One outgoing edge of a binary-trie node.
struct node
{
    // Index of the child node this edge leads to, or -1 when there is none.
    int num;
    // Number of currently stored values whose path uses this edge
    // (incremented on insert, decremented when a value is consumed).
    int val;

    node() {}
    node(int nu, int va) : num(nu), val(va) {}
};

// Node pool: nxt[v][b] is the edge of node v taken for bit value b.
node nxt[mx*32][2];
/**
 * Expand the low 31 bits of x into the global `bits` array, most significant
 * bit first: bits[0] receives bit 30 of x and bits[30] receives bit 0.
 * bits[31] and bits[32] are cleared.
 *
 * The loop runs over a fixed 31 positions. A "shift until zero" loop on a
 * signed value does not terminate for negative x and would write past the
 * end of `bits`; a fixed-width loop gives the same result for every
 * non-negative x and stays in bounds for any int.
 */
void Getbit(int x)
{
    memset(bits,0,sizeof(bits));
    for (int pos = 0; pos < 31; ++pos) {
        // Position `pos` (counted from the top) holds bit 30 - pos of x.
        bits[pos] = (x >> (30 - pos)) & 1;
    }
}
/**
 * Take the next unused slot of the node pool, mark both of its edges as
 * absent (child index -1, count 0) and return the slot's index.
 */
int newnode()
{
    const int id = tot;
    for (int b = 0; b < 2; ++b) {
        nxt[id][b].num = -1;
        nxt[id][b].val = 0;
    }
    ++tot;
    return id;
}
void insert()
{
int dep = root;
for(int i=0;i<31;i++){
if(nxt[dep][bits[i]].num==-1)
nxt[dep][bits[i]].num = newnode();
nxt[dep][bits[i]].val++;
dep = nxt[dep][bits[i]].num;
}
}
/**
 * Greedily walk the trie along the bit pattern px, consuming one stored value.
 *
 * At each of the 31 levels the edge matching px[i] is taken when its counter
 * is non-zero; otherwise the opposite edge is taken and 2^(30-i) is added to
 * the result. The counter of every edge taken is decremented, which removes
 * the matched value from the trie.
 *
 * @param px  bit pattern with at least 31 entries, most significant bit
 *            first (the layout Getbit writes into `bits`). The pattern is
 *            read from this argument, not from the global array, so the
 *            function honours whatever buffer the caller passes.
 * @return    the sum of 2^(30-i) over the levels where the opposite edge
 *            had to be taken.
 *
 * Precondition: the trie still holds at least one value; no check is made
 * for an exhausted trie.
 */
int find(bool *px)
{
    int ans = 0,dep = root;
    for(int i=0;i<31;i++){
        int branch = px[i];
        if(!nxt[dep][branch].val){
            // Nothing left on the matching side: this bit will differ.
            ans += 1<<(30-i);
            branch ^= 1;
        }
        nxt[dep][branch].val--;
        dep = nxt[dep][branch].num;
    }
    return ans;
}
int main()
{
scanf("%d",&n);
root = newnode();
//memset(nxt[root],-1,sizeof(nxt[root]));
for(int i=0;i<n;i++) scanf("%d",Ai + i);
for(int i=0;i<n;i++){
scanf("%d",&m);
//memset(bits,0,sizeof(bits));
Getbit(m);
insert();
}
for(int i=0;i<n;i++){
Getbit(Ai[i]);
printf("%d%c",find(bits),i==n-1?'\n':' ');
}
return 0;
}
set模拟:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// Upper bound used to size the Ai array.
const int mx = 3e5 + 10;
// n: element count read by main; m: scratch variable for each value read
// in the second input pass; Ai: the first input sequence.
// bits/top: output of Getbit() - bits[k] is bit k of its argument (least
// significant bit at index 0) and top is the number of bits scanned.
// ans: accumulator shared between main() and Get_in().
int n,m,Ai[mx],bits[33],ans,top;
// st[k] holds the stored values bucketed by main() under the index k of
// their highest set bit; main() places both 0 and 1 in st[0].
multiset <int> st[33];
/**
 * Decompose x into the global `bits` array, least significant bit first,
 * and store the number of bits scanned (index of the highest set bit plus
 * one, or 0 for x == 0) in the global `top`.
 *
 * The value is shifted as an unsigned copy. Shifting a negative signed
 * value right never reaches zero, which would loop forever and write past
 * the end of `bits`; with the unsigned copy the loop ends after at most as
 * many steps as an unsigned has bits (32 on the usual platforms, which fits
 * in bits[33]). Results for non-negative x are unchanged.
 */
void Getbit(int x)
{
    memset(bits,0,sizeof(bits));
    top = 0;
    for (unsigned rest = static_cast<unsigned>(x); rest != 0; rest >>= 1) {
        if (rest & 1u) bits[top] = 1;
        ++top;
    }
}
// NOTE(review): not referenced anywhere in the visible code; looks like a
// leftover - confirm before removing.
bool vis[33];
bool Get_in(int y,int x)
{
int ret = 1<<y,maxx = 2e9;
for(int i=x;i>=0;i--){
if(bits[i]){
auto it = st[y].lower_bound(ret+(1<<i));
if(it!=st[y].end()&&(*it)<maxx) ret += 1<<i;
else ans += 1<<i;
}else{
auto it = st[y].lower_bound(ret+(1<<i));
if(it==st[y].begin()||(*(--it))<ret) ans += 1<<i,ret += 1<<i;
else maxx = ret + (1<<i);
}
}
st[y].erase(st[y].lower_bound(ret));
return 1;
}
/**
 * Driver of the multiset solution.
 *
 * Reads n and the n values of Ai, then reads n more values and stores each
 * one in st[k], where k is the index of its highest set bit (0 and 1 both
 * end up in st[0]). For every Ai in order it selects one stored value,
 * removes it and prints the accumulated `ans` (the bits in which Ai and the
 * selected value differ). Results are separated by spaces and the last one
 * is followed by a newline.
 *
 * Selection order for each Ai:
 *   1. the highest set bit j of Ai for which a stored value with the same
 *      highest bit exists - Get_in() then settles the lower bits;
 *   2. otherwise a stored 0, which leaves Ai unchanged;
 *   3. otherwise a value from the lowest non-empty bucket.
 */
int main()
{
    scanf("%d",&n);
    for(int i=0;i<n;i++) scanf("%d",Ai + i);
    for(int i=0;i<n;i++){
        scanf("%d",&m);
        // num = index of the highest set bit of the value (0 for 0 and 1).
        int num = 0,val = m;
        while(m>>1){
            m >>= 1;
            num++;
        }
        st[num].insert(val);
    }
    for(int i=0;i<n;i++){
        ans = 0,Getbit(Ai[i]);
        int flag = 0;
        for(int j=top-1;j>=0;j--)
        {
            if(!bits[j]) continue;
            // Bucket 0 mixes the values 0 and 1, and only a stored 1 can
            // cancel bit 0. Checking just for non-emptiness would send
            // Get_in() after a 1 that is not there and make it erase end().
            const bool usable = j>0 ? !st[j].empty()
                                    : st[0].find(1)!=st[0].end();
            if(usable)
            {
                flag = Get_in(j,j-1);
                break;
            }
            ans += 1<<j;
        }
        if(!flag){
            // A single logarithmic lookup; count() would walk every stored
            // zero on each query.
            auto zero = st[0].find(0);
            if(zero!=st[0].end())
            {
                st[0].erase(zero);
                printf("%d%c",Ai[i],i==n-1?'\n':' ');
                continue;
            }
            for(int j=0;j<32;j++){
                if(st[j].size()){
                    // Bit j of Ai is clear here, so it is set in the result.
                    ans = 1<<j;
                    Get_in(j,j-1);
                    // Bits of Ai above j pass through unchanged.
                    for(int k=j+1;k<32;k++)
                        if(bits[k]) ans += 1<<k;
                    break;
                }
            }
        }
        printf("%d%c",ans,i==n-1?'\n':' ');
    }
    return 0;
}