P r o b l e m \mathrm{Problem} Problem
- 给定 n n n 个结点的无向完全图。每个点有一个点权为 a i a_i ai。连接 i i i 号结点和 j j j 号结点的边的边权为 a i ⊕ a j a_i\oplus a_j ai⊕aj。
- 求这个图的 MST 的权值。
- 1 ≤ n ≤ 2 × 1 0 5 1\le n\le 2\times 10^5 1≤n≤2×105, 0 ≤ a i ≤ 2 30 0\le a_i \le 2^{30} 0≤ai≤230。
S o l u t i o n \mathrm{Solution} Solution
这道题首先我们要意识到,异或我们可以考虑trie树。(啊呸意识到也不会)
根据位运算的运算原则,我们需要优先满足较高位。对于两个数
a
i
a_i
ai 和
a
j
a_j
aj 来说,贡献是从
l
c
a
(
i
,
j
)
\mathrm{lca}(i,j)
lca(i,j) 开始计算的。
(图出自洛谷题解我是白嫖怪)
如图,我们发现:
- 2 2 2 和 3 3 3 的贡献是在它们上面的公共祖先处计算的。
- 再比方说我们需要计算出 1 1 1 的贡献,那么我们肯定是从 2 2 2 和 3 3 3 那块子树里面去找点相连,而不是从块面的外树中找点相连,否则一定会造成位较高数位的贡献,一定会使答案更大。对于这种情况,我们需要在它们的 l c a \mathrm{lca} lca中去进行处理。
因此我们得到结论:一个子树的节点一定会在子树内部完成连边。
因此对于这种情况,我们以 x x x 为上面所说的 l c a \mathrm{lca} lca 点去处理这个问题。
- 若 x x x 的左子树和右子树都存在节点,我们可以在枚举左子树中的每一个节点,去右子树中查找。答案即为左右子树某两点相连的贡献+当前分叉的贡献+左子树的贡献+右子树的贡献。
- 若左子树存在节点,递归左子树。
- 若右子树存在节点,递归右子树。
那么我们现在的问题就在于如何枚举左子树的每一个节点,我们发现排序以后子树内的每一个节点都是一段连续的区间。因此我们只需要对每一个节点记录左右端点即可。
复杂度分析: t r i e \mathrm{trie} trie的插入询问单词 O ( log n ) O(\log n) O(logn),那么初始化复杂度为 O ( n log n ) O(n\log n) O(nlogn);对于主程序部分的处理中,由于深度最大为 log n \log n logn级别的,每一个最多只会被它的父亲进行查询操作,所有单个节点的最坏复杂度为 O ( log 2 n ) O(\log ^2 n) O(log2n)。
综上,总时间复杂度为: O ( n log 2 n ) O(n \log^2 n) O(nlog2n)
C o d e \mathrm{Code} Code
#include <bits/stdc++.h>
#define lc trie[x][0]
#define rc trie[x][1]
#define int long long
using namespace std;
const int N = 5e6;
int n, cnt, root;
int a[N], L[N], R[N], trie[N][2];
int read(void)
{
int s = 0, w = 0; char c = getchar();
while (!isdigit(c)) w |= c == '-', c = getchar();
while (isdigit(c)) s = s*10+c-48, c = getchar();
return w ? -s : s;
}
void insert(int &Now, int x, int i)
{
if (Now == 0) Now = ++ cnt;
L[Now] = min(L[Now], x);
R[Now] = max(R[Now], x);
if (i < 0) return;
int val = (a[x] >> i) & 1;
insert(trie[Now][val], x, i - 1);
return;
}
int ask(int Now, int x, int i)
{
if (i < 0) return 0;
int val = (x >> i) & 1;
if (trie[Now][val]) return ask(trie[Now][val], x, i - 1);
return ask(trie[Now][val^1], x, i - 1) + (1 << i);
}
int Dfs(int x, int dep)
{
if (dep < 0) return 0;
if (R[lc] and R[rc])
{
int res = 1e18;
for (int i=L[lc];i<=R[lc];++i)
res = min(res, ask(rc, a[i], dep - 1));
return Dfs(lc, dep - 1) + Dfs(rc, dep - 1) + res + (1 << dep);
}
if (R[lc]) return Dfs(lc, dep - 1);
if (R[rc]) return Dfs(rc, dep - 1);
return 0;
}
signed main(void)
{
freopen("qingyuqaq.in","r",stdin);
freopen("qingyuqaq.out","w",stdout);
n = read();
for (int i=1;i<=n;++i) a[i] = read();
sort(a + 1, a + n + 1);
memset(L, 30, sizeof L);
for (int i=1;i<=n;++i) insert(root, i, 30);
cout << Dfs(root, 30) << endl;
return 0;
}