第一次做dsu on tree维护异或的题目,思路还不是很活跃,再去找相关的几道题去做一下。
题意:给定你n个点,和他们对应的节点值ai。然后让你找出所有满足题目要求的式子的点对的异或和。(i,j)和(j,i)是一对相同的点对。
思路:很明显想要解决这个式子,我们有两个问题需要去解决,第一,如何去找到(i,j)这个点对值的LCA。第二,如何去找到满足条件的所有(i,j)点对值的亦或者。
我们这样考虑,这道题直接去找给出的i和j的LCA是不现实的,但是这道题并没有让我们去找到某一个确切的LCA是谁,而是要统计所有lca的答案。这样想,树上每一个有子节点的节点都会成为它的子节点的LCA,或者说每一个作为LCA的答案都是由它的不同子树贡献的。所以这道题其实就变成了统计某一个节点的子树贡献问题,这样就可以用dsu on tree来解决了。
第二个问题,如何快速的维护每一个满足要求的点对要求的异或值,直接暴力的添加和修改然后遍历肯定是会超时的。我们看到求的是什么?异或值,看到异或我们就应该去想能否和二进制产生联系。 这样看,我们能不能不去具体的找出每一个确切的i和j分别是什么,就能求出答案。考虑每一个这样的点对(i,j)对答案产生贡献时会有什么规律。
我们用一个数组cnt[x][y][z]来维护,当前值为x的节点的下标值所对应的二进制数第y位是z的节点有多少个。
我们假设他们LCA的节点值c,当前节点的节点值为a,那么很明显满足要求的节点的节点值就应该是c^a,如果当前这个点的下标值第i为1,那么只有当满足要求的那个点二进制下对应的第i位为0时,这一对点异或才会对答案产生2 ^ i大小的贡献。因此我们只需要用dsu去维护这个cnt数组的个数,然后去统计对于子树中每一个节点它能和插入的点产生的贡献为多少即可。
但是需要注意的是,dsu on tree一般的模板题中,统计答案和值插入对于某个节点的一个子树是同时完成的。但是这道题中,我们要统计不同子树之间的贡献,同一棵子树是不能产生贡献的,所以我们把统计函数和值插入函数修改分开,先统计完当前状态下的答案后,再把当前这个子树的节点值插入更新。
剩下的细节就在代码里:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 7;
const int N = 2e6 + 7;//这里开的比1e6大一些比较好因为异或后可能会变大
int head[MAXN],tot,siz[MAXN],son[MAXN],flag,a[MAXN];
int cnt[MAXN][22][2];ll ans;
struct node{
int next,to;
}edge[MAXN<<1];
void addedge(int u,int v){ edge[++tot].to = v;edge[tot].next = head[u];head[u] = tot; }
void pre_work(int u,int fa){
siz[u] = 1;
for(int i = head[u];i;i = edge[i].next){
int v = edge[i].to;
if(v == fa) continue;
pre_work(v,u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
void modify(int u,int fa,int val){
for(int i = 0;i <= 20;i ++)
cnt[a[u]][i][(u>>i)&1] += val;//对应位置修改 以便统计对答案的贡献
for(int i = head[u];i;i = edge[i].next){
int v = edge[i].to;
if(v == fa) continue;
modify(v,u,val);
}
}
void get_sum(int u,int fa,int lca){//维护的一当前节点为lca的贡献 u代表的是当前lca节点的子树根节点
int aim = a[lca] ^ a[u];
if(aim < N/2){//符合条件的取值 大于1e6不可能能取到
for(int i = 0;i <= 20;i ++)
ans += 1ll * cnt[aim][i][!((u>>i)&1)] * (1<<i);//对异或值的贡献
}
for(int i = head[u];i;i = edge[i].next){
int v = edge[i].to;
if(v == fa) continue;
get_sum(v,u,lca);//lca不能变
}
}
void dfs(int u,int fa,int keep){
for(int i = head[u];i;i = edge[i].next){//把轻儿子先全部走一遍
int v = edge[i].to;
if(v == fa || v == son[u]) continue;
dfs(v,u,0);
}
if(son[u]) dfs(son[u],u,1),flag = son[u];//走一遍重儿子
//cal()一般的cal这里相当于分为两个函数实现 修改和统计答案
for(int i = head[u];i;i = edge[i].next){
int v = edge[i].to;
if(v == fa || v == flag) continue;
get_sum(v,u,u);
modify(v,u,1);//第一次先加入
}
for(int i = 0;i <= 20;i ++) cnt[a[u]][i][(u>>i)&1] += 1;//这里相比于一般的cal函数少了一开始的根节点u的值
flag = 0;
if(!keep) modify(u,fa,-1);//消除轻儿子影响
}
int main(){
int n;
scanf("%d",&n);
for(int i = 1;i <= n;i ++) scanf("%d",&a[i]);
int u,v;
for(int i = 1;i < n;i ++){
scanf("%d%d",&u,&v);
addedge(u,v);addedge(v,u);
}
pre_work(1,0);
dfs(1,0,0);
printf("%lld\n",ans);
return 0;
}