计算点对相关的一种dsu on tree做法;
个人总结与普通的区别主要是:在count处变为枚举轻子树累计答案,枚举后加入并继续枚举下一个,最后删除。.
这题还要借助数组来把数字变为01数组以加速异或加法。
bit[num[u]][i][(u>>i)&1] += val;
分别为节点值,位数,位为0或1。
#include "bits/stdc++.h"
//#define int long long
using namespace std;
const int maxn = 3e5;
int cnt = 0;
int head[maxn];
struct node {
int to, next;
} a[maxn];
void add(int u, int v) {
a[++cnt].to = v;
a[cnt].next = head[u];
head[u] = cnt;
}
int num[maxn];
int siz[maxn];
int son[maxn];
void dfsFir(int u, int fa) {
siz[u] = 1;
for (int i = head[u]; i; i = a[i].next) {
int v = a[i].to;
if (v == fa) continue;
dfsFir(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) {
son[u] = v;
}
}
}
int flag;
long long ans;
int bit[1100000][21][2];
void count(int u, int fa, int val) {
for (int i = 0; i <= 20; ++i) {
bit[num[u]][i][(u>>i)&1] += val;
}
for (int i = head[u]; i ; i = a[i].next) {
int v = a[i].to;
if (v == fa || v == flag) continue;
count(v,u,val);
}
}
void get(int u,int fa,int t) //针对每颗轻子树先得到贡献,再对其进行count加上数量
{
int val = (num[u]^num[t]); //找值为val的点
if (val <= (1<<20))
{
for (int i = 0; i <= 20; ++i) {
ans += (1ll<<i)*bit[val][i][!((u>>i)&1)]; //因为是异或的和所以加!
}
}
for (int i = head[u]; i ; i = a[i].next) {
int v = a[i].to;
if (v == fa) continue;
get(v,u,t);
}
}
void dfsSec(int u, int fa, int keep) {
for (int i = head[u]; i; i = a[i].next) {
int v = a[i].to;
if (v == fa || v == son[u]) continue;
dfsSec(v, u, false);
}
if (son[u]) {
dfsSec(son[u], u, true);
flag = son[u];
}
for (int i = head[u]; i ; i = a[i].next) { //count,点对的做法,枚举子树累计答案,枚举后加入继续枚举下一个,最后删除
int v = a[i].to;
if (v == fa || v == flag) continue;
get(v,u,u);
count(v,u,1);
}
flag = 0;
for (int i = 0; i <= 20; ++i) { //顶点没记
bit[num[u]][i][(u>>i)&1] += 1;
}
if (!keep) {
count(u, fa, -1);
}
}
signed main() {
ios::sync_with_stdio(0);
int n, m;
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> num[i];
}
for (int i = 1; i <= n-1; ++i) {
int u,v;
cin >> u >> v;
add(u,v);
add(v,u);
}
dfsFir(1,0);
dfsSec(1,0,0);
cout << ans << endl;
}