XOR Inverse
XOR Inverse CF1419C
题意:
给定n个数字,要求给出最小的
x
x
x使得
b
i
=
a
i
⊕
x
b_i=a_i\oplus x
bi=ai⊕x之后的序列的逆序对总和最小
思路:
遇到二进制必贪心定理^^,但是这边得转化一下,考虑每一位取
0
0
0
o
r
or
or
1
1
1的代价怎么计算。
因为逆序对的定义是
i
<
j
&
&
b
i
>
b
j
i<j \&\& b_i>b_j
i<j&&bi>bj,这里涉及到比大小,放到二进制中就是对于第一位不同的位数进行大小比较,是0就小,是1就大。
这里对于第一位不同的可以想到用Trie把前缀相同的数字合并,在进行
T
r
i
e
树
d
f
s
Trie树dfs
Trie树dfs时候对于第
i
i
i位取
0
0
0
o
r
or
or
1
1
1的代价进行记录,(代码里面我用
d
p
[
i
]
[
b
i
t
]
dp[i][bit]
dp[i][bit]进行记录):
- d p [ d e p ] [ 0 ] dp[dep][0] dp[dep][0]表示在 d e p dep dep位取 1 1 1所付出的代价,这种情况 T r i e Trie Trie内的取 0 0 0的子树在 ⊕ \oplus ⊕之后反而更大,所以就是找到子树 0 0 0内的 i d id id小于子树 1 1 1内 i d id id的个数。大小计算可以遍历子树 0 0 0内寻找当前 i d ( 代 码 写 的 m p [ T r i e [ n o d e ] [ 0 ] ] [ i ] ) id(代码写的mp[Trie[node][0]][i]) id(代码写的mp[Trie[node][0]][i])小于子树 1 1 1内的 i d ( 代 码 写 的 m p [ T r i e [ n o d e ] [ 1 ] ] [ t t ] ) id(代码写的mp[Trie[node][1]][tt]) id(代码写的mp[Trie[node][1]][tt])
- d p [ d e p ] [ 1 ] dp[dep][1] dp[dep][1]表示在 d e p dep dep位取 0 0 0所付出的代价,这种情况可以利用上一种情况已经统计出的逆序对,对所有的序列对 ( m p [ T r i e [ n o d e ] [ 0 ] ] . s i z e ( ) ∗ m p [ T r i e [ n o d e ] [ 1 ] ] . s i z e ( ) ) (mp[Trie[node][0]].size()*mp[Trie[node][1]].size()) (mp[Trie[node][0]].size()∗mp[Trie[node][1]].size())减去上一问的,就是本问的答案。
注意点:宝贝开LL^^,maxn开大一个数量级( 4 e 6 4e6 4e6)
int Trie[maxn][2], tot = 1, a[maxn];
LL dp[maxn][2];
vector<int> mp[maxn];
void insert(int x,int id) {
int ch, p = 1;
for (int i = 30; i >= 0; i--) {
ch = x >> i & 1;
if (Trie[p][ch] == 0)Trie[p][ch] = ++tot;
p = Trie[p][ch];
mp[p].push_back(id);
}
return;
}
void dfs(int dep,int node) {
if (Trie[node][0])dfs(dep - 1, Trie[node][0]);
if (Trie[node][1])dfs(dep - 1, Trie[node][1]);
if (Trie[node][0] == 0 || Trie[node][1] == 0)return;
LL sum=0;
LL tt = 0;
for (int i = 0; i < mp[Trie[node][0]].size(); i++) {
while (tt < mp[Trie[node][1]].size() && mp[Trie[node][1]][tt] < mp[Trie[node][0]][i]) tt++;
sum += tt;
}
dp[dep][0] += sum;
dp[dep][1] += (LL)mp[Trie[node][0]].size()*(LL)mp[Trie[node][1]].size() - sum;
}
int main() {
int T, n, x;
LL ans = 0, u = 0;
//sci(T);
T = 1;
//cout << qpow(2, 30) << endl;
while (T--){
sci(n);
for (int i = 1; i <= n; i++) {
sci(a[i]);
insert(a[i], i);
}
dfs(30, 1);
for (int i = 0; i <= 30; i++) {
ans += min(dp[i][0], dp[i][1]);
if (dp[i][1] < dp[i][0])u += (1 << i);
}
printf("%lld %lld\n", ans, u);
}
return 0;
}