题链:https://codeforces.com/gym/102832/problem/F
题意:求 ,翻译一下就是求当 的和。
思路:做过类似的题,传送门 。看到LCA,可以考虑将问题转化到子树上。因为,以u为LCA的无序对(v1,v2)一定在u的子树中,其中v1,v2要属于u的不同儿子。
那么,我们可以维护子树中每个数的下标集合。在计算u子树的答案时,统计u的轻儿子v的贡献时,要先统计答案,再将v的子树里的数加上。
这样是为了避免将v的子树的无序对(v1,v2)统计为答案,虽然,但显然(v1,v2)的LCA不是u。
但是,如果直接用set、vector什么的维护每个数的下标,那么复杂度太高(lc大佬能过,tql,yyds)。看到异或这种二进制运算,就是引导你去往按位计算想。
我们要算的是所有符合的无序对的合,我们当然可以每位每位的算,也就是说对于当前数a[v1]的下标v1(假设v1的二进制位为10101)。
对于v1的第k位(从右往左数,最右边为第0位),我们可以看看u的子树中,有多少值为a[u]^a[v1]的下标的的第k位与它不同(异或嘛,不同为1)。假设有num个,那么对答案的贡献就是num*(1<<k)。
也就是对于v1(10101),对答案的贡献为 第0位为0的个数*(1<<0) + 第1位为1的个数*(1<<1) + 第2位为0的个数*(1<<2) + 第3位为1的个数*(1<<3) + 第4位为0的个数*(1<<4)。
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e5+10;
const int M = 1e6+10;
int n,a[N],bit[N][20],limit;
struct node{
int to,nex;
}g[N<<1];
int head[N],cnt=0;
int sz[N],son[N],SON;
void getsz(int u,int fa){
sz[u]=1;
int maxx=-1;
for(int i=head[u];~i;i=g[i].nex){
int v=g[i].to;
if(v==fa) continue;
getsz(v,u);
sz[u]+=sz[v];
if(sz[v]>maxx) maxx=sz[v],son[u]=v;
}
}
void add(int u,int v){
g[cnt]=node{v,head[u]};
head[u]=cnt++;
}
int temp[N],cntt=0;
ll ans=0;
int num[M][20][2];//num[i][j][k]:表示当前子树中,值为i的下标中,第j位为k的个数。
void dfs1(int u,int fa,int root){
temp[++cntt]=u;
if((a[root]^a[u])<M){//有可能越界
for(int i=0;i<=limit;i++)
ans+=1LL*num[a[root]^a[u]][i][bit[u][i]^1]*(1<<i);
}
for(int i=head[u];~i;i=g[i].nex){
int v=g[i].to;
if(v==fa) continue;
dfs1(v,u,root);
}
}
void dfs2(int u,int fa){
for(int k=0;k<=limit;k++)
num[a[u]][k][bit[u][k]]--;
for(int i=head[u];~i;i=g[i].nex){
int v=g[i].to;
if(v==fa) continue;
dfs2(v,u);
}
}
void dfs(int u,int fa,bool keep){
for(int i=head[u];~i;i=g[i].nex){
int v=g[i].to;
if(v==fa||v==son[u]) continue;
dfs(v,u,0);//计算所有轻儿子的答案,不保存贡献。
}
if(son[u]) dfs(son[u],u,1);//计算重儿子的答案并保存贡献。
for(int i=head[u];~i;i=g[i].nex){
int v=g[i].to;
if(v==fa||v==son[u]) continue;
cntt=0;
dfs1(v,u,u);//要先统计答案,再将v的子树里的数加上。
for(int j=1;j<=cntt;j++){
for(int k=0;k<=limit;k++)
num[a[temp[j]]][k][bit[temp[j]][k]]++;
}
}
for(int i=0;i<=limit;i++)//计算以u为一个端点的无序对的答案,不过保证ai!=0,所以没有这种情况。
ans+=1LL*num[0][i][bit[u][i]^1]*(1<<i);
for(int k=0;k<=limit;k++)//加上u的贡献
num[a[u]][k][bit[u][k]]++;
if(!keep)
dfs2(u,fa);
}
int main(void){
scanf("%d",&n);
limit=ceil(log2(1.0*n));
for(int i=1;i<=n;i++){
head[i]=-1;
int x=i,j=0;//预处理下标的每一位
while(x){
bit[i][j++]=(x&1);
x>>=1;
}
}
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
getsz(1,0);
dfs(1,0,0);
printf("%lld\n",ans);
return 0;
}