题意
对一个有点权的树,求以下公式的值:
∑ i = 1 n − 1 ∑ j = i + 1 n [ a i X O R a j = = a L C A ( i , j ) ] ( i X O R j ) \sum_{i = 1}^{n - 1} \sum_{j = i + 1}^{n} [a_i \ XOR \ a_j == a_{LCA(i, j)}] (i \ XOR \ j) i=1∑n−1j=i+1∑n[ai XOR aj==aLCA(i,j)](i XOR j)
思路
总体上看是比较套路的 dsu on tree
,树上启发式合并。
如果对这个算法一无所知,那先学习一下入门基础,然后做一下典型例题,不算太难。
再看这道题:
- 首先考虑要求的东西是异或和,想到拆位处理
- 对于树上每个点 w 都有
p[w]
表示 w 在当前位是 0 或者 1。
- 对于树上每个点 w 都有
- 考虑一个比较暴力的 dfs :
- 每个点 x 都有
cnt[x][1e6][2]
- 表示 x 点有一个 1e6 长度的数组,统计 0 和 1 的个数
- 现在考虑遍历到一个点 u
- 枚举 u 的子节点 v
- 遍历子树
subtree[v]
的所有结点 wres += cnt[u][a[u] ^ a[w]][p[w] ^ 1]
这表示 cnt 之前统计的点中,存在与a[w]
异或等于a[u]
, 并且在当前枚举的二进制位与p[w]
不同的点个数。
- 再次遍历子树
subtree[v]
的所有结点 wcnt[u][a[w]][p[w]]++
把 w 这个点加入总父亲的统计。- 这一步必须和上面分开遍历,因为我们不能让子树内部产生贡献。
- 遍历子树
- 每个点 x 都有
- 但显然我们不能开那么大的数组,这个时候就用
dsu on tree
让cnt可以复用,减去第一维。 - 如果按照传统的写法,会需要很多递归去实现原来得暴力的部分。我这里记了 dfs 序维护所有子树的区间,直接用 for 区间的方式枚举子树,比较容易写。
第三点就是基于自己对dsu on tree
模板的理解慢慢摸出来的,比较套路,没什么特别的技巧。g关键是cnt数组的复用。
不恰当地比喻,就像是把树拆成了几条重链,用技巧按顺序枚举这几条重链,轻链几乎是和原来差不多的方式暴力枚举。
复杂度应该是: N l o g 2 N Nlog^2N Nlog2N
代码
const int MAXN = 2e5 + 59;
const int MAXN_2e6 = 2e6 + 59;
using namespace std;
int n;
int c[MAXN];
int p[MAXN];
int sz[MAXN];
int son[MAXN];
int L[MAXN], R[MAXN], Id[MAXN];
ll ans;
vector<int> g[MAXN];
void set_son(int u, int f) {
L[u] = ++*L;
Id[*L] = u;
sz[u] = 1;
son[u] = -1;
for (auto v: g[u]) {
if (v == f) continue;
set_son(v, u);
sz[u] += sz[v];
if (son[u] == -1 || sz[v] > sz[son[u]]) {
son[u] = v;
}
}
R[u] = *L;
}
int cnt[MAXN_2e6][2];
ll res;
void dsu(int u, int f, bool ish) {
for (auto v: g[u]) {
if (v == f) continue;
if (v == son[u]) continue;
dsu(v, u, false);
}
if (son[u] > 0) dsu(son[u], u, true);
for (auto v: g[u]) {
if (v == f) continue;
if (v == son[u]) continue;
for (int i = L[v], w = Id[i]; i <= R[v]; w = Id[++i]) {
res += cnt[c[u] ^ c[w]][p[w] ^ 1];
}
for (int i = L[v], w = Id[i]; i <= R[v]; w = Id[++i]) {
cnt[c[w]][p[w]]++;
}
}
cnt[c[u]][p[u]]++;
if (!ish) {
for (int i = L[u], w = Id[i]; i <= R[u]; w = Id[++i]) {
cnt[c[w]][0] = cnt[c[w]][1] = 0;
}
}
}
void solve(int kaseId = -1) {
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> c[i];
}
for (int i = 2, u, v; i <= n; ++i) {
cin >> u >> v;
g[u].emplace_back(v);
g[v].emplace_back(u);
}
set_son(1, -1);
for (int k = 0; k < 22; ++k) {
int kp = 1 << k;
if (kp > n) break;
res = 0;
for (int i = 1; i <= n; ++i) {
p[i] = (i >> k) & 1;
}
dsu(1, -1, false);
// debug(res);
ans = ans + res * kp;
}
cout << ans << endl;
}