本题采用DSU On Tree + 拆位,这个题目采用启发式合并很容易想到,关键是如何加速计算贡献的过程,首先,遍历每一个点然后寻找前缀树中对应的点是必不可少的,但是这里,我们可以开一个数组A[X][Y][Z],前缀树中,设一个点的编号为ID,则数组表示在前缀树中所有权值为X的编号中第Y位为Z的数量,这样的话,我们只需要对遍历到的点进行二进制拆位,检查前缀树中0/1的数量便可以加速计算贡献的过程了
//#define LOCAL#include<bits/stdc++.h>usingnamespace std;#define ll long long#define mem(a, b) memset(a,b,sizeof(a))#define INF 0x3f3f3f3f#define DNF 0x7f#define DBG printf("this is a input\n")#define fi first#define se second#define mk(a, b) make_pair(a,b)#define pb push_back#define LF putchar('\n')#define SP putchar(' ')#define p_queue priority_queue#define CLOSE ios::sync_with_stdio(0); cin.tie(0)template<typename T>voidread(T &x){x =0;char ch =getchar();ll f =1;while(!isdigit(ch)){if(ch =='-')f *=-1;ch =getchar();}while(isdigit(ch)){x = x *10+ ch -48; ch =getchar();}x *= f;}template<typename T,typename... Args>voidread(T &first, Args&... args){read(first);read(args...);}template<typename T>voidwrite(T arg){T x = arg;if(x <0){putchar('-'); x =- x;}if(x >9){write(x /10);}putchar(x %10+'0');}template<typename T,typename... Ts>voidwrite(T arg, Ts ... args){write(arg);if(sizeof...(args)!=0){putchar(' ');write(args ...);}}usingnamespace std;
ll gcd(ll a, ll b){return b ==0? a :gcd(b, a % b);}
ll lcm(ll a, ll b){return a /gcd(a, b)* b;}constint N =2e6+5;int n;int c[N], siz[N], son[N], pp[N], vis[N];
ll sum;int t[N][25][2];
vector <int> edge[N];void dfs1 (int u,int fa){
siz[u]=1;for(auto v : edge[u]){if(v != fa){
dfs1 (v, u);
siz[u]+= siz[v];if(siz[v]> siz[son[u]])
son[u]= v;}}}void add (int u ,int k){for(int i =0; i <21; i ++)
t[c[u]][i][(u >> i)&1]+= k;}void get (int u ,int fa,int lca){int num = u;int temp =(c[lca]^ c[u]);for(int i =0; i <21; i ++)
sum +=1ll* t[temp][i][(1^((num>>i)&1))]* pp[i];for(auto v : edge[u])if(v != fa)
get (v, u, lca);}void seet (int u ,int fa,int x){
add (u, x);for(auto v : edge[u])if(v != fa)
seet (v, u, x);}void count (int u ,int fa,int k){for(auto v : edge[u]){if(v != fa &&!vis[v]){if(k ==1) get (v , u , u);
seet (v , u ,k);}}
add (u, k);}void dfs2 (int u ,int fa,int opt){for(auto v : edge[u])if(v != fa && v != son[u])dfs2(v, u,0);if(son[u])dfs2(son[u], u,1), vis[son[u]]=1;
count (u, fa,1);if(son[u]) vis[son[u]]=0;if(!opt)count(u ,fa,-1);}intmain(){
read (n);
pp[0]=1;for(int i =1; i <21; i ++)
pp[i]= pp[i -1]*2;for(int i =1; i <= n ; i ++)
read (c[i]);for(int i =1; i < n ; i ++){int u , v;
read (u , v);
edge[u].pb(v);
edge[v].pb(u);}
dfs1 (1,0);
dfs2 (1,0,0);
write (sum), LF;}