链接:LeetCode 5696(周赛题号;题库中为第 1803 题「统计异或值在范围内的数对有多少」)
题意:
给出一个整数数组 nums 以及两个整数 low 和 high,求有多少个下标对 (i, j)(0 ≤ i < j < n)满足
$$low \leq nums[i] \oplus nums[j] \leq high$$
其中 $\oplus$ 表示按位异或(即代码中的 `^`)。
思路 :
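要求异或值落在某个范围内,很自然地想到用 01 字典树,对每一个数单独统计。从最高位往低位走。如果在更高位都相等的前提下,某一位的异或结果已经严格大于 low 的对应位,那么后面的位无论取什么,都一定大于 low。high 同理,只是方向变成严格小于。因此沿字典树递归,维护两个状态:当前前缀是否已经严格大于 low,以及是否已经严格小于 high。两个状态同时成立时,当前子树里的所有数都合法,直接加上子树的计数即可;走完所有位仍未被剪枝,说明异或值恰好落在区间内,也要计入。每个数对会从两端各被统计一次,所以最后要除以 2。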
代码:
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<string.h>
#include<vector>
#include<cmath>
#include<string>
#include<map>
#include<queue>
using namespace std;
// Global state shared by Insert / Find / main of the trie solution.
using ll = long long;      // not used by this program
const int maxn = 300007;   // capacity of a[] and of the trie node arrays
int a[maxn];               // input values, read 1-indexed by main
int ans;                   // matches accumulated by Find (each pair seen from both ends)
int tot = 1;               // next free trie node id; node 0 is the root
int tire[maxn][3];         // trie children; only slots [0] and [1] are used
int cnt[maxn];             // how many inserted values pass through each node
int n;                     // number of input values
int low, high, now;        // query bounds and the value Find is matching against
// Adds x to the global binary trie, walking from bit 20 down to bit 0.
// Missing nodes get fresh ids from `tot`; every node on the path has its
// pass-through counter `cnt` incremented.
void Insert(int x) {
    int node = 0;  // root
    for (int bit = 20; bit >= 0; --bit) {
        const int dir = (x >> bit) & 1;
        int &next = tire[node][dir];
        if (!next) {
            next = tot++;
        }
        node = next;
        ++cnt[node];
    }
}
void Find(int pos , int p , bool t1, bool t2) {
if(t1 && t2 || pos == -1){
ans += cnt[p];
return;
}
int n1 = tire[p][0];
int n2 = tire[p][1];
int bp = ((now >> pos) & 1);
int bl = ((low >> pos) & 1);
int bh = ((high >> pos) & 1);
if(n1 != 0){
int x1 = (bp ^ 0);
if((x1 >= bl && x1 <= bh) || (x1 <= bh && t1) || (x1 >= bl && t2)){
Find(pos - 1 , n1 , max(t1 , (x1 > bl)) , max(t2 , (x1 < bh)));
}
}
if(n2 != 0){
int x1 = (bp ^ 1);
if((x1 >= bl && x1 <= bh) || (x1 <= bh && t1) || (x1 >= bl && t2)){
Find(pos - 1 , n2 , max(t1 , (x1 > bl)) , max(t2 , (x1 < bh)));
}
}
}
// Reads n, low, high and n values, builds the trie, then for each value counts
// the stored values whose XOR with it lies in [low, high]. Every unordered pair
// (i, j), i != j, is found twice (once from each end), hence the final halving.
int main(){
    scanf("%d%d%d", &n, &low, &high);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        Insert(a[i]);
    }
    for (int i = 1; i <= n; i++) {
        now = a[i];
        Find(20, 0, 0, 0);
    }
    // Find also matches each value with itself (XOR 0) whenever 0 is in range,
    // which, assuming 0 <= low <= high, happens exactly when low == 0. Those n
    // self-pairs would make ans odd/inflated, so drop them before halving.
    if (low == 0) {
        ans -= n;
    }
    printf("%d\n", ans / 2);
    return 0;
}
过题代码:
// LeetCode submission: counts index pairs (i, j), i < j, with
// low <= nums[i] ^ nums[j] <= high, using a binary trie over bits 20..0.
// countPairs rebuilds all state on every call, so one object can be reused.
class Solution {
    static constexpr int kTopBit = 20;  // values are assumed to fit in 21 bits

    long long ans = 0;  // matches found by Find (each pair twice, plus self-matches)
    // Flat binary trie: the child of node p along bit b is tire[2 * p + b];
    // 0 means "no child" (node 0 is the root and is never anyone's child).
    std::vector<int> tire = std::vector<int>(2, 0);
    std::vector<int> cnt = std::vector<int>(1, 0);  // values passing through each node
    int now = 0;  // value whose partners Find is currently counting

public:
    // Inserts x into the trie, creating nodes on demand and bumping the
    // pass-through counter of every node on its path.
    void Insert(int x) {
        int p = 0;
        for (int i = kTopBit; i >= 0; --i) {
            const int b = (x >> i) & 1;
            if (tire[2 * p + b] == 0) {
                const int id = static_cast<int>(cnt.size());
                cnt.push_back(0);
                tire.push_back(0);
                tire.push_back(0);
                tire[2 * p + b] = id;  // index again after push_back may reallocate
            }
            p = tire[2 * p + b];
            ++cnt[p];
        }
    }

    // Adds to `ans` how many values y under node p satisfy low <= (now ^ y) <= high,
    // with bits above `pos` already decided.
    //   t1: a higher bit already made the XOR strictly greater than low.
    //   t2: a higher bit already made the XOR strictly smaller than high.
    void Find(int pos, int p, bool t1, bool t2, int low, int high) {
        if ((t1 && t2) || pos == -1) {
            ans += cnt[p];  // whole subtree qualifies
            return;
        }
        const int bp = (now >> pos) & 1;
        const int bl = (low >> pos) & 1;
        const int bh = (high >> pos) & 1;
        for (int b = 0; b < 2; ++b) {
            const int child = tire[2 * p + b];
            if (child == 0) {
                continue;
            }
            const int x = bp ^ b;  // XOR bit produced by following b
            if ((t1 || x >= bl) && (t2 || x <= bh)) {
                Find(pos - 1, child, t1 || x > bl, t2 || x < bh, low, high);
            }
        }
    }

    // Returns the number of pairs (i, j), i < j, whose XOR lies in [low, high].
    // Assumes non-negative inputs below 2^21 and 0 <= low <= high.
    int countPairs(std::vector<int>& nums, int low, int high) {
        // Reset everything so repeated calls on the same object are independent.
        ans = 0;
        tire.assign(2, 0);
        cnt.assign(1, 0);

        for (int x : nums) {
            Insert(x);
        }
        for (int x : nums) {
            now = x;
            Find(kTopBit, 0, false, false, low, high);
        }
        // With low == 0 every element also matched itself (XOR 0); drop those.
        if (low == 0) {
            ans -= static_cast<long long>(nums.size());
        }
        return static_cast<int>(ans / 2);  // each unordered pair was seen twice
    }
};
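新写法:先求出合法区间,再求贡献。把异或值的目标区间按二进制拆成若干个对齐的区间,每个区间的大小都是 2 的幂。对一个固定的数 val,若 y ^ val 落在某个对齐区间内,那么所有这样的 y 也构成一个连续区间,用权值线段树查询其中数的个数,就是该区间的贡献。答案按「异或值 ≤ high 的对数」减去「异或值 ≤ low−1 的对数」计算,最后同样除以 2。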
类似题目:(原文此处的链接缺失)
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<string.h>
#include<vector>
#include<cmath>
#include<string>
#include<map>
#include<queue>
using namespace std;
// Global state of the segment-tree solution.
using ll = long long;      // not used by this program
const int maxn = 4000007;  // capacity of a[] and of the segment-tree arrays
int a[maxn];               // input values, read 1-indexed by main
int ans;                   // signed accumulation of pair counts from Find
int tot = 1;               // not used by this program
int sum[maxn];             // segment-tree node counts, rooted at index 1
int cnt[maxn];             // not used by this program
int n;                     // number of input values
int low, high, now;        // query bounds and the value currently processed
// Point increment: records one more occurrence of value `pos` in node `rt`,
// which covers [l, r], then refreshes the counts on the way back up.
void update(int pos, int l , int r, int rt){
    if (l != r) {
        const int mid = (l + r) / 2;
        if (pos <= mid) {
            update(pos, l, mid, rt << 1);
        } else {
            update(pos, mid + 1, r, rt << 1 | 1);
        }
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
        return;
    }
    ++sum[rt];  // leaf for value `pos`
}
// Returns how many recorded values lie in [L, R], searching node `rt`, which
// covers [l, r].
int query(int L ,int R, int l , int r , int rt){
    if (L <= l && r <= R) {
        return sum[rt];  // node range fully inside the query
    }
    const int mid = (l + r) / 2;
    int total = 0;
    if (L <= mid) {
        total += query(L, R, l, mid, rt << 1);
    }
    if (mid < R) {
        total += query(L, R, mid + 1, r, rt << 1 | 1);
    }
    return total;
}
// Counts recorded values y with (y ^ val) in [L, R]. The XOR range is split
// into aligned blocks of the implicit range [l, r]; a block at depth `pos` has
// size 2^pos. For a fully covered block [l, l + 2^pos - 1], the matching y are
// exactly those whose bits pos..19 equal the bits of (l ^ val), with the low
// bits free. That is one contiguous range, whose size is looked up with query()
// (assumes val < 2^20).
int Find(int L , int R ,int val , int l , int r , int pos){
    if (L <= l && r <= R) {
        const int keepHigh = ((1 << 20) - 1) ^ ((1 << pos) - 1);  // bits pos..19
        const int lo = (l ^ val) & keepHigh;
        const int hi = lo + (1 << pos) - 1;
        return query(lo, hi, 0, (1 << 20), 1);
    }
    const int mid = (l + r) >> 1;
    int found = 0;
    if (L <= mid) {
        found += Find(L, R, val, l, mid, pos - 1);
    }
    if (mid < R) {
        found += Find(L, R, val, mid + 1, r, pos - 1);
    }
    return found;
}
// Reads n, low, high and n values into the value segment tree. For each value,
// count(XOR <= high) - count(XOR <= low - 1) gives its matches. Every unordered
// pair (i, j), i != j, is counted from both ends, hence the final halving.
int main(){
    scanf("%d%d%d", &n, &low, &high);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        update(a[i], 0, (1 << 20), 1);
    }
    // XOR results are never negative, so clamp the lower bound at 0. Find
    // recurses forever when its R is negative, so never call it that way.
    const int lo = low > 0 ? low : 0;
    if (lo <= high) {
        for (int i = 1; i <= n; i++) {
            now = a[i];
            if (lo > 0) {
                ans -= Find(0, lo - 1, now, 0, (1 << 20) - 1, 20);
            }
            ans += Find(0, high, now, 0, (1 << 20) - 1, 20);
        }
        // When 0 is in range, each value also matched itself (XOR 0). Remove
        // those n self-pairs so only i != j pairs remain before halving.
        if (lo == 0) {
            ans -= n;
        }
    }
    printf("%d\n", ans / 2);
    return 0;
}
/*
Sample input (expected output: 8):
5 5 14
9 8 4 2 1
*/