题目链接: Strange Memory
大致题意
给定一棵有n个点的树, 根节点为1.
对于两个点a, b. 若满足 w[lca(a, b)] = w[a] ⊕ w[b] (w数组为该点的权值) , 则称(a, b)为满足要求的一个pair.
**pair(a, b)的权值为 a⊕b. **
问: 所有满足要求的pair的权值总和是多少. PS: pair(a, b)和pair(b, a)我们认为是同一个pair
解题思路
dsu on tree
我们先稍作分析:
由于题目中节点的权值≥1, 因此我们会发现, 任意两点的lca必然不是这两点中的任何一个点.
异或的性质: a ^ b = c <==> a ^ c = b
弱化版问题
我们先考虑这个问题的弱化版本, 如果我们想要统计所有满足要求的pair总对数, 我们如何去做?
dsu on tree.
我们可以枚举每一个点x作为点a和点b的lca时, pair(a, b)是否合法.
那么在进行计算的时候, 当我们遍历到点a时, 我们需要找出满足要求的点b.
已知 w[a] ^ w[b] = w[x], 但我们只知道w[x]和w[a], 不知道w[b]. 此时我们就可以运用异或的性质
==> w[b] = w[x] ^ w[a].
此时产生的贡献为, 权值为w[b]的点的数量. 我们可以通过一个数组, 来记录已经统计过的点中权值为val的点有多少个.
到这里, 会有两个细节问题:
Q: 我们遍历到点a后, 我们只统计了在它之前出现过的w[b]的数量, 如果后面还有呢?
A: 那当你后面再遍历到点b后, 此时你又会把贡献加回去, 是不会少计算贡献的.
~~
Q: 我们计算的时候, 已经订好了点x为lca节点, 如果我们计算的点a和点b都在一条链上, 此时他们的lca可能就不是点x了. 这种情况我们怎么处理?
A: 我们发现, 以点x为lca的两个点必然不在一条链上, 所以我们可以跑完一条链后, 再把这条链的贡献加入数组中. 这样我们就可以保证数组中已经计算的点必然与当前遍历到的点a不在一条链上, 符合要求.
到此, 这个弱化版问题就已经解决完毕了!
思路一
我们如果仿照弱化版的思路去做呢? 那么相当于这个题我就要对于每一个权值的位置开一个vector, 每次遍历到点a时, 我需要去遍历w[b]所在的vector.
此时我们分析一下复杂度: 假设有一半点的权值都是w[a], 另外一半都是w[b], 还有一个点x的权值是w[a]^w[b].
此时当我们以x为lca时, 我们成功跑了一次n2. 在考虑到dsu本身的logn. 所以这个方法最坏的复杂度会到 O(n2logn) 但是原题的数据比较随机, 居然能水过Orz. 我也会在后面放出这个方式的代码的.
到这里, 我们发现这个思路的复杂度已经爆炸了. 于是考虑优化.
优化的重点肯定是放在, 怎么不去遍历那个vector. 我们就得考虑异或的另外一个性质.
思路二
我们考虑到每次把两个数字给异或起来太慢了, 但是假设这个数字我们只看他二进制其中的一位呢?
即: 我现在有一个数字a, 我要把a和一个b[]数组中的所有数字都异或起来, 且他们二进制中我们只考虑一位的情况.
假设我们当前考虑的是二进制中的第k位. 如果a的第k位是1, 则b数组中只有第k位是0的这些数字会产生贡献. 反之亦然.
假设b数组中产生贡献的数字个数为num个, 那么对于答案的贡献是num * (1 << k)
.PS: 如果这里绕不明白, 你就认为这个数字的二进制只有一位数, 就别考虑第k位分析的情况了.
我们通过上述分析发现, 对于二进制每一位的情况, 我们可以O(1)的时间去统计答案.
题目中说明最大节点的权值为1E6, 那么参照上述思路, 我们只需要统计log2(1E6)位的情况即可.
到这里, 我们的思路就清晰了. 我们想出了一个复杂度为 O(nlognlog(1E6)) 的方法. 是可行的.
AC代码
思路一: https://pasteme.cn/126121
思路二: https://pasteme.cn/126122
/* 思路一 */
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
#define debug(a) cout << #a << " = " << a << endl;
using namespace std;
typedef long long ll;
const int N = 1E5 + 10;
ll res = 0; int pval; //当前lca的权值
int w[N];
vector<int> edge[N];
int son[N], sz[N];
void dfs1(int x, int fa) { //得到重儿子
sz[x] = 1;
for (auto& to : edge[x]) {
if (to == fa) continue;
dfs1(to, x);
sz[x] += sz[to];
if (sz[to] > sz[son[x]]) son[x] = to;
}
}
vector<int> v; //v数组存储当前链上遍历过的所有节点
unordered_map<int, vector<int>> mp; //val [index];
void calc(int x, int fa) {
v.push_back(x);
int target = pval ^ w[x]; //要判断一下, 超过1E6必然没有解
if (target <= 1E6 + 5) for (auto& op : mp[target]) res += x ^ op;
for (auto& to : edge[x]) {
if (to == fa) continue;
calc(to, x);
}
}
void dfs2(int x, int fa, int tp) {
for (auto& to : edge[x]) {
if (to == fa or to == son[x]) continue;
dfs2(to, x, 0);
}
if (son[x]) { //遍历重儿子
dfs2(son[x], x, 1);
mp[w[son[x]]].push_back(son[x]); //dfs2不会更新x节点的信息, 因此会缺少son[x]的信息
}
pval = w[x];
for (auto& to : edge[x]) {
if (to == fa or to == son[x]) continue;
calc(to, x);
for (auto& op : v) mp[w[op]].push_back(op); //更新当前链的信息
v.clear();
}
if (!tp) mp.clear();
}
int main()
{
int n; cin >> n;
rep(i, n) scanf("%d", &w[i]);
rep(i, n - 1) {
int a, b; scanf("%d %d", &a, &b);
edge[a].push_back(b), edge[b].push_back(a);
}
dfs1(1, 0);
dfs2(1, 0, 1);
cout << res << endl;
return 0;
}
/* 思路二 */
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
#define debug(a) cout << #a << " = " << a << endl;
using namespace std;
typedef long long ll;
const int N = 1E5 + 10, LEN = 20;
ll res = 0; int pval; //当前lca的权值
int w[N];
vector<int> edge[N];
int son[N], sz[N];
void dfs1(int x, int fa) { //得到重儿子
sz[x] = 1;
for (auto& to : edge[x]) {
if (to == fa) continue;
dfs1(to, x);
sz[x] += sz[to];
if (sz[to] > sz[son[x]]) son[x] = to;
}
}
vector<int> v; //v数组存储当前链上遍历过的所有节点
int cou[int(1E6 + 10)][LEN][2]; //存储数位信息
/* fact函数用于计算数位情况, 因为要计算的地方比较多, 干脆写个函数 */
void fact(int x, int val, int c) { for (int i = 0; i < LEN; ++i) cou[val][i][x >> i & 1] += c; }
void calc(int x, int fa) {
v.push_back(x);
int target = pval ^ w[x]; //要判断一下, 超过1E6必然没有解
if (target < 1E6 + 5) for (int i = 0; i < LEN; ++i) res += cou[target][i][(x >> i & 1) ^ 1] * (1 << i);
for (auto& to : edge[x]) {
if (to == fa) continue;
calc(to, x);
}
}
void del(int x, int fa) {
fact(x, w[x], -1);
for (auto& to : edge[x]) {
if (to == fa) continue;
del(to, x);
}
}
void dfs2(int x, int fa, int tp) {
for (auto& to : edge[x]) {
if (to == fa or to == son[x]) continue;
dfs2(to, x, 0);
}
if (son[x]) dfs2(son[x], x, 1);
pval = w[x];
for (auto& to : edge[x]) {
if (to == fa or to == son[x]) continue;
calc(to, x);
for (auto& op : v) fact(op, w[op], 1); //更新当前链的信息
v.clear();
}
fact(x, w[x], 1); //最后要更新当前点的信息, 以便正确清空数据, 或正确保存数据.
if (!tp) del(x, fa);
}
int main()
{
int n; cin >> n;
rep(i, n) scanf("%d", &w[i]);
rep(i, n - 1) {
int a, b; scanf("%d %d", &a, &b);
edge[a].push_back(b), edge[b].push_back(a);
}
dfs1(1, 0);
dfs2(1, 0, 1);
cout << res << endl;
return 0;
}