题目链接:点击进入
题目
题意
给你 n 个点的值,n 个点加边生成一棵树,边权就是两点的异或值,要求最后所有边权和最小。输出最小边权和。
思路
将 n 个数插入到01字典树上,对于树上的每个节点,记录这个节点保存有原数组的那些数(就是原数组中有哪些数经过这个节点),因为这些数可能在原数组中是无序的所以不好保存,因此我们可以将原数组排序后再插入字典树,这样每个节点所保存的数就是连续的,也就是每个节点只记录这个连续区间的边界即可。
遍历01字典树,
对于两个孩子的节点,左右孩子代表的区间是不同的,代表未连接的两个集合(这里默认两个集合内部已经连接好了),连接这两个集合,最好的肯定是只用一条边将两个集合连起来,这条边肯定是两边所代表的区间内的数互相异或的最小值。因此我们可以枚举左子树所代表的所有数,对于每个数,在右子树上贪心求异或最小值,将所有枚举左子树得到的最小值取最小就是连接两个集合的最小代价。然后递归继续解决左右子树。
对于单个孩子的节点,只有一个集合,继续递归解决。
对于没有孩子的节点,递归返回。
最后将递归过程中的所有最小代价相加就是最终答案。
代码
//#pragma GCC optimize(3)//O3
//#pragma GCC optimize(2)//O2
#include<iostream>
#include<string>
#include<map>
#include<set>
//#include<unordered_map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include<stack>
#include<algorithm>
#include<iomanip>
#include<cmath>
#include<fstream>
#define X first
#define Y second
#define base 233
#define pb push_back
#define INF 0x3f3f3f3f3f3f3f3f
#define pii pair<int,int>
#define lowbit(x) x & -x
#define inf 0x3f3f3f3f
// #define int long long
//#define double long double
//#define rep(i,x,y) for(register int i = x; i <= y;++i)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const double pai=acos(-1.0);
const int maxn=8e6+10;
const int mod=1e9+7;
const double eps=1e-9;
const int N=5e3+10;
/*--------------------------------------------*/
inline int read()
{
int k = 0, f = 1 ;
char c = getchar() ;
while(!isdigit(c)){if(c == '-') f = -1 ;c = getchar() ;}
while(isdigit(c)) k = (k << 1) + (k << 3) + c - 48 ,c = getchar() ;
return k * f ;
}
/*--------------------------------------------*/
int n,a[maxn],tot,val[maxn];
int t[maxn][2],l[maxn],r[maxn];
void insert(int pos)
{
int p=0;
for(int i=31;i>=0;i--)
{
int to=(a[pos]>>i)&1;
if(!t[p][to]) t[p][to]=++tot,l[tot]=r[tot]=pos;
p=t[p][to];
l[p]=min(l[p],pos);
r[p]=max(r[p],pos);
}
val[p]=pos;
}
int find(int p,int x,int deep)
{
for(int i=deep;i>=0;i--)
{
int to=(x>>i)&1;
if(t[p][to])
p=t[p][to];
else
p=t[p][!to];
}
return a[val[p]];
}
ll dfs(int p,int deep)
{
if(t[p][0]&&t[p][1])
{
ll minn=INF;
for(int i=l[t[p][0]];i<=r[t[p][0]];i++)
minn=min(minn,1LL*a[i]^find(t[p][1],a[i],deep-1));
return minn+dfs(t[p][0],deep-1)+dfs(t[p][1],deep-1);
}
else if(t[p][0]) return dfs(t[p][0],deep-1);
else if(t[p][1]) return dfs(t[p][1],deep-1);
else return 0;
}
int main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);cout.tie(0);
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
sort(a+1,a+n+1);
for(int i=1;i<=n;i++) insert(i);
printf("%lld",dfs(0,31));
return 0;
}