题目
给你一个长度为
n
n
n 的数组
a
a
a,利用
a
i
a_{i}
ai、
a
j
a_{j}
aj可以构成一个权值为
a
i
∧
a
j
a_{i}^{\wedge}a_{j}
ai∧aj 的边。
要求求出利用数组构成的生成树的最小的权值和。
n
<
=
200000
n<=200000
n<=200000
a
i
<
=
(
1
<
<
31
)
−
1
a_{i}<=(1<<31)-1
ai<=(1<<31)−1
题解思路
解法1
官方题解中给出了利用
B
o
r
u
v
k
a
Boruvka
Boruvka算法来解决的思路。
有关这个算法的学习,翻到最下面
我们利用这个算法来求最小的生成树。
这样我们每次需要求出这个连通块和除了他本身处在的连通块中的元素能异或到的最小值,以及下标,并且把他们合并了。
求异或最值就可以想到利用01字典树。
而字典树本身是能和连通块一样进行合并的,每颗字典树就相当于点,利用并查集找出他归属的树根节点,并查集本身又能记录每个连通块的 s i z e size size,所以我们可以进行启发式合并,每次只合并 s i z e size size小的连通块。
字典树的合并思路
void dfs(int p , int q )
{
cnt[p] += cnt[q] ;
ed[p] = ed[q] ;
for (int i = 0 ; i < 2 ; i++ )
{
if ( tr[q][i] )
{
if (!tr[p][i])
tr[p][i] = tr[q][i] ;
else
dfs(tr[p][i],tr[q][i]) ;
}
}
}
void merge(int x , int y )
{
x = find(x) ;
y = find(y) ;
if ( x == y )
return ;
if ( sz[x] < sz[y] )
dfs(root[y],root[x]) , fa[x] = y , sz[y] += sz[x] ;
else
dfs(root[x],root[y]) , fa[y] = x , sz[x] += sz[y] ;
}
这样我们只能处理出每个连通块的情况,但是要求异或值改就需要遍历所有其他连通块?
这里利用之前可持续化的一种思想,我们可以另外建一个完整的字典树(既有全部元素)。并且在每个连通块都记录每个节点的
c
n
t
cnt
cnt,这样我们求最小异或值的时候,就能直接判断某个节点能否找到对应的点。
这里利用此时元素的所在的连通块的
c
n
t
cnt
cnt来与完整的作差,因为完整的肯定包含所有情况(有点可持续化的最终版本减之前的版本的感觉),利用差值判断其他连通块是否含有这个点。
int tp = (u >> i & 1) ;
if ( tr[pp][tp] && cnt[tr[pp][tp]] - cnt[tr[p][tp]] > 0 )
p = tr[p][tp] , pp = tr[pp][tp] ;
else
p = tr[p][!tp] , pp = tr[pp][!tp] , ans += 1ll<<i ;
我们找到每个连通快的最值就直接用树根节点和他们合并即可。
解法2
我们讲所有点先插入字典树,然后考虑按位获取每一位的贡献。
如果这个位有两条路径,那么他们直接异或的话,这个位必然会有值,我们需要将这个值最小化。
要保证最后图连通就必须要有一个
l
c
a
lca
lca 来将他们连接,
此时我们可以枚举左右子树的点来找出这个最小代价的
l
c
a
lca
lca 。
为了能顺序的枚举,我们可以先将原序列排序,来让子树的数产生一个连续的区间(仔细思考,此时的01字典树是一个2叉树,如果我们排序了,那么插入的值在字典树中呈现的肯定也是连续的)。
此时枚举将右子树的数放入左子树产生的字典树中贪心求最小代价即可。
相当于我们每次将数组分此时最大位贪心的合并。因为排序了,所以在字典树中也呈有序的。
AC代码
解法一
/*从你的全世界路过.*/
#include <bits/stdc++.h>
//#include <unordered_map>
//priority_queue
#define PII pair<long long,long long>
#define ll long long
using namespace std;
const int INF = 0x3f3f3f3f;
const int N = 200100;
int n ;
int a[N] ;
int tr[N*50][2] ;
int root[N] ;
int ed[N*50] ;
int cnt[N*50] ;
int fa[N] ;
int idx ;
long long dis[N] ;
int sz[N] ;
int disid[N] ;
int find(int x )
{
if ( fa[x] != x )
fa[x] = find(fa[x]) ;
return fa[x] ;
}
void dfs(int p , int q )
{
cnt[p] += cnt[q] ;
ed[p] = ed[q] ;
for (int i = 0 ; i < 2 ; i++ )
{
if ( tr[q][i] )
{
if (!tr[p][i])
tr[p][i] = tr[q][i] ;
else
dfs(tr[p][i],tr[q][i]) ;
}
}
}
void merge(int x , int y )
{
x = find(x) ;
y = find(y) ;
if ( x == y )
return ;
if ( sz[x] < sz[y] )
dfs(root[y],root[x]) , fa[x] = y , sz[y] += sz[x] ;
else
dfs(root[x],root[y]) , fa[y] = x , sz[x] += sz[y] ;
}
void insert(int p , int u , int id)
{
ed[p] = id ;
for (int i = 30 ; i >= 0 ; i-- )
{
int t = (u >> i & 1 ) ;
if (!tr[p][t])
idx++,tr[p][t] = idx ;
p = tr[p][t] ;
cnt[p]++ ;
ed[p] = id ;
}
}
PII query(int p , int u )
{
int pp = root[0] ;
long long ans = 0 ;
for (int i = 30 ; i >= 0 ; i-- )
{
int tp = (u >> i & 1) ;
if ( tr[pp][tp] && cnt[tr[pp][tp]] - cnt[tr[p][tp]] > 0 )
p = tr[p][tp] , pp = tr[pp][tp] ;
else
p = tr[p][!tp] , pp = tr[pp][!tp] , ans += 1ll<<i ;
}
return {ans,ed[pp]} ;
}
void solve()
{
cin >> n ;
for (int i = 1 ; i <= n ; i++ )
cin >> a[i] ;
sort(a+1,a+1+n) ;
idx++ ;
root[0] = idx ;
for (int i = 1 ; i <= n ; i++ )
{
idx++ ;
root[i] = idx ;
fa[i] = i ;
sz[i] = 1 ;
insert(root[i],a[i],i) ;
insert(root[0],a[i],i) ;
}
long long ans = 0 ;
while (1)
{
int falg = 0 ;
for (int i = 1 ; i <= n ; i++ )
dis[i] = (1ll<<31)-1 ;
for (int i = 1 ; i <= n ; i++ )
{
auto sk = query(root[find(i)],a[i]) ;
int x = find(i) , y = find(sk.second) ;
if ( x == y )
continue;
if ( sk.first < dis[x] )
dis[x] = sk.first , disid[x] = y ;
if ( sk.first < dis[y] )
dis[y] = sk.first , disid[y] = x ;
}
for (int i = 1 ; i <= n ; i++ )
{
if ( dis[i] < (1ll<<31)-1 && find(i) != find(disid[i]) )
{
merge(find(i),find(disid[i])) ;
ans += dis[i] ;
falg = 1 ;
}
}
if (!falg)
break;
}
cout << ans << "\n" ;
}
int main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
solve() ;
return 0 ;
}
解法二
/*从你的全世界路过.*/
#include <bits/stdc++.h>
//#include <unordered_map>
//priority_queue
#define PII pair<int,int>
#define ll long long
using namespace std;
const int INF = 0x3f3f3f3f;
const int N = 200100;
int n ;
long long a[N];
int tr[N*40][3] ;
int l[N*40] , r[N*40] ;
int idx ;
void insert( int u , int id )
{
int p = 1 ;
for (int i = 30 ; i >= 0 ; i-- )
{
int tp = u >> i & 1 ;
if ( !tr[p][tp] )
idx++,tr[p][tp] = idx ;
if ( !l[p] )
l[p] = id ;
r[p] = id ;
p = tr[p][tp] ;
}
if ( !l[p] )
l[p] = id ;
r[p] = id ;
}
long long query( int p , int u , int deth )
{
if ( deth == -1 )
return 0 ;
int tp = u >> deth & 1 ;
if ( tr[p][tp] )
return query(tr[p][tp],u,deth-1) ;
else
return query(tr[p][!tp],u,deth-1) + (1ll<<deth);
}
long long dfs(int p , int deth)
{
if ( deth == -1 )
return 0;
if ( tr[p][0] && tr[p][1] )
{
long long ans = 1e15 ;
for (int i = l[tr[p][0]] ; i <= r[tr[p][0]] ; i++ )
{
ans = min(ans,query(tr[p][1],a[i],deth-1)) ;
}
return dfs(tr[p][1],deth-1) + dfs(tr[p][0],deth-1) + ans + (1ll<<deth) ;
}else if ( tr[p][0] )
return dfs(tr[p][0],deth-1) ;
else if ( tr[p][1] )
return dfs(tr[p][1],deth-1) ;
return 0 ;
}
void solve()
{
cin >> n ;
for (int i = 1 ; i <= n ; i++ )
cin >> a[i] ;
sort(a+1,a+1+n) ;
idx++ ;
for (int i = 1 ; i <= n ; i++ )
insert(a[i],i) ;
cout << dfs(1,30) << "\n" ;
}
int main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
solve() ;
return 0 ;
}