题目大意
给定数组 { A n } \{A_n\} {An}和 { B n } \{B_n\} {Bn},求 ∑ 1 ≤ i < j ≤ n min ( A i ⊕ A j , B i ⊕ B j ) \sum_{1\le i<j\le n}\min(A_i\oplus A_j,B_i\oplus B_j) ∑1≤i<j≤nmin(Ai⊕Aj,Bi⊕Bj), n ≤ 250000 n\le 250000 n≤250000
题解
这题还是很妙的,要求的东西看上去没什么关联,所以想办法找点关联,发现
A
i
⊕
A
j
⊕
B
i
⊕
B
j
A_i\oplus A_j\oplus B_i\oplus B_j
Ai⊕Aj⊕Bi⊕Bj的最高位
1
1
1的位置是
A
i
⊕
A
j
A_i\oplus A_j
Ai⊕Aj和
B
i
⊕
B
j
B_i\oplus B_j
Bi⊕Bj最高位不同的位置
设
C
i
=
A
i
⊕
B
i
C_i=A_i\oplus B_i
Ci=Ai⊕Bi,考虑根据上面那条性质分治,每次把在当前位数
d
e
p
dep
dep为
0
0
0的
C
i
C_i
Ci放到一个集合中,为
1
1
1的
C
i
C_i
Ci放到另一个集合中,那么这两个集合之间
A
i
⊕
A
j
A_i\oplus A_j
Ai⊕Aj和
B
i
⊕
B
j
B_i\oplus B_j
Bi⊕Bj不同的最高位就是
d
e
p
dep
dep,判断它们的大小关系只需要判断它们在
d
e
p
dep
dep位上的大小关系,讨论一下算出贡献,再分治下去,分治的每一层算贡献的复杂度为
O
(
n
log
n
)
O(n\log n)
O(nlogn),分治深度为
O
(
log
n
)
O(\log n)
O(logn),所以总时间复杂度为
O
(
n
log
2
n
)
O(n\log^2n)
O(nlog2n)
code
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
void read(int &res)
{
res=0;char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while('0'<=ch&&ch<='9') res=(res<<1)+(res<<3)+(ch^48),ch=getchar();
}
const int N=3e5+100,Bl=19;
int n,a[N+10],b[N+10],p[N+10],p0[N+10],p1[N+10],c[2][Bl+10],A[2][2][Bl+10],B[2][2][Bl+10];
ll solve(int l,int r,int dep=Bl)
{
if(l>r) return 0;
if(dep<0)
{
ll res=0;
for(int i=0;i<=Bl;i++) c[0][i]=c[1][i]=0;
for(int i=l;i<=r;i++)
{
for(int j=0,k;j<=Bl;j++) k=((a[p[i]]&(1<<j))!=0),res+=1ll*c[k^1][j]*(1<<j);
for(int j=0,k;j<=Bl;j++) k=((a[p[i]]&(1<<j))!=0),c[k][j]++;
}
return res;
}
ll res=0;p0[0]=p1[0]=0;
for(int i=l;i<=r;i++)
{
if((a[p[i]]^b[p[i]])&(1<<dep)) p0[++p0[0]]=p[i];
else p1[++p1[0]]=p[i];
}
for(int j=0;j<=Bl;j++) for(int i=0;i<=1;i++) for(int k=0;k<=1;k++) A[i][k][j]=B[i][k][j]=0;
for(int i=1,k1;i<=p0[0];i++)
{
k1=((a[p0[i]]&(1<<dep))!=0);
for(int j=0,k;j<=Bl;j++)
{
k=((a[p0[i]]&(1<<j))!=0),A[k1][k][j]++,
k=((b[p0[i]]&(1<<j))!=0),B[k1][k][j]++;
}
}
for(int i=1,k1;i<=p1[0];i++)
{
k1=((a[p1[i]]&(1<<dep))!=0);
for(int j=0,k;j<=Bl;j++)
{
k=((a[p1[i]]&(1<<j))!=0);
res+=1ll*A[k1][k^1][j]*(1<<j);
k=((b[p1[i]]&(1<<j))!=0);
res+=1ll*B[k1^1][k^1][j]*(1<<j);
}
}
for(int i=1;i<=p0[0];i++) p[l+i-1]=p0[i];
for(int i=1;i<=p1[0];i++) p[l+p0[0]-1+i]=p1[i];
int mid=l+p0[0]-1;
res+=solve(l,mid,dep-1)+solve(mid+1,r,dep-1);
return res;
}
int main()
{
read(n);
for(int i=1;i<=n;i++) p[i]=i;
for(int i=1;i<=n;i++) read(a[i]);
for(int i=1;i<=n;i++) read(b[i]);
printf("%lld",solve(1,n,Bl));
return 0;
}