Strange Memory
Once there was a rooted tree. The tree contained $n$ nodes, numbered $1, \dots, n$. The node numbered 1 was the root of the tree. Besides, every node $i$ was assigned a number $a_i$. You were surprised to find that there were several pairs of nodes $(i, j)$ satisfying $a_i \oplus a_j = a_{\mathrm{lca}(i,j)}$, where $\oplus$ denotes the bitwise XOR operation, and $\mathrm{lca}(i,j)$ is the lowest common ancestor of $i$ and $j$, or formally, the lowest (i.e. deepest) node that has both $i$ and $j$ as descendants.
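Unfortunately, you cannot remember all such pairs, and only remember the sum of $i \oplus j$ over all distinct pairs of nodes $(i, j)$ satisfying the above property. Note that $(i, j)$ and $(j, i)$ are considered the same here. In other words, you will only be able to recall $$\sum_{i=1}^{n} \sum_{j=i+1}^{n} \left[a_i \oplus a_j = a_{\mathrm{lca}(i,j)}\right] \cdot (i \oplus j),$$ where $[P]$ equals 1 if the condition $P$ holds and 0 otherwise.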
You are asked to calculate it now so that you can memorize it better in the future.
Input
The first line contains a single integer $n$ ($2 \le n \le 10^5$).
The second line contains $n$ integers $a_1, a_2, \dots, a_n$ ($1 \le a_i \le 10^6$).
Each of the next n−1 lines contains 2 integers u and v (1≤u,v≤n,u≠v), indicating that there is an edge between u and v. It is guaranteed that these edges form a tree.
Output
Print what you will memorize in the future.
Example
Input
6
4 2 1 6 6 5
1 2
2 3
1 4
4 5
4 6
Output
18
树上启发式合并,对每个要计算结果的节点,对每个以其儿子节点为根的子树的节点,每次先查找set中与其权值异或为要计算结果的节点的权值,再这些节点它放入set,后来发现set TLE了,用了个数组记录就A了
#include <bits/stdc++.h>
using namespace std;
// Array bounds, sized for the stated constraints (n <= 1e5, 1 <= a_i <= 1e6).
// maxn bounds every per-node array: nodes are numbered 1..n, and tmp also
// uses slot 0 as its length.
// maxv bounds the value-indexed table num. Values are below 2^20, so the XOR
// of two values (the lookup key used by the query) is below 2^20 as well.
// Keeping the two bounds separate shrinks num from 2e6+9 rows to 2^20 rows
// and the node arrays from 2e6+9 entries to about 1e5.
const int maxn = 100000 + 9;
const int maxv = 1 << 20;

// Adjacency lists of the tree; each edge is stored in both directions.
std::vector<int> edge[maxn];

// val[i]  : a_i, the number assigned to node i
// Size[i] : size of the subtree rooted at i
// son[i]  : child of i with the largest subtree (0 when i has no children)
// flag    : not referenced by the visible code
int val[maxn],Size[maxn],son[maxn],flag[maxn];

// num[v][0]          : how many nodes with value v are in the active set
// num[v][k], 1..20   : how many of those nodes have bit k-1 of their index set
int num[maxv][21];

int vis[maxn];   // 1 while a node is in the active set, 0 otherwise
int tmp[maxn];   // tmp[0] = count, tmp[1..tmp[0]] = nodes gathered for insertion
int Pow[23];     // Pow[i] = 2^i for 0 <= i <= 20, filled by init()
long long ans;   // running answer
// Precomputes Pow[i] = 2^i for i in [0, 20].
//
// The counter table num is deliberately NOT cleared here. It is a global,
// so it already has static storage duration and is zero-initialized before
// main runs. The former memset over the entire table was redundant, and it
// forced every page of that very large array to be touched and committed.
// Consequently init() resets no state and is meant to be called once.
void init()
{
    for(int i = 0; i<=20; i++)
    {
        Pow[i] = 1 << i;
    }
}
// First DFS pass (recursive) over the subtree of x, where fa is x's parent.
// Fills Size[] with subtree sizes. It also records in son[x] the child with
// the strictly largest subtree; on a tie the earliest such child in edge[x]
// wins. For a node without children son[x] is not written, so it keeps
// whatever it held before.
void con_son(int x,int fa)
{
    Size[x] = 1;
    int best_size = -1;
    for(const int child : edge[x])
    {
        if(child == fa)
            continue;
        con_son(child, x);
        Size[x] += Size[child];
        if(Size[child] > best_size)
        {
            best_size = Size[child];
            son[x] = child;
        }
    }
}
void conclute(int x,int fa,int add,int lca)
{
if(add==1)
{
tmp[++tmp[0]] = x;
int num_tmp = val[x]^lca;
int x_tmp =x;
int now = 1;
// for(int i=0;i<=20;i++)
// printf("%d ",num[num_tmp][i]);
// printf("******%d\n",x_tmp);
for(int i=1;i<=20;i++)
{
int nn = x_tmp%2;
x_tmp/=2;
if(nn==0)
{
ans+=(long long )(Pow[now-1]*num[num_tmp][now]);
}
else
{
ans +=(long long )(Pow[now-1]*(num[num_tmp][0]-num[num_tmp][now]));
}
now++;
}
//cout<<ans<<endl;
}
else
{
vis[x]=0;
num[val[x]][0]--;
int x_tmp = x;
int now = 1;
while(x_tmp)
{
int nn =x_tmp%2;
x_tmp/=2;
if(nn==1)
{
num[val[x]][now]--;
}
now++;
}
}
for(int i = 0; i<edge[x].size(); i++)
{
int to = edge[x][i];
if(to==fa)
continue;
conclute(to,x,add,lca);
}
}
// Inserts node v into the active set. It marks vis[v], increments the
// per-value count num[val[v]][0], and for every set bit k-1 of the index v
// increments the per-bit count num[val[v]][k].
static void insert_into_set(int v)
{
    vis[v] = 1;
    num[val[v]][0]++;
    for(int bit = 1, rest = v; rest; bit++, rest >>= 1)
    {
        if(rest & 1)
            num[val[v]][bit]++;
    }
}

// DSU-on-tree pass over the subtree of x (fa = parent of x).
// It accumulates into ans every valid pair whose lowest common ancestor is x.
// save != 0: the whole subtree of x (x included) is left in the active set.
// save == 0: x's children's subtrees are erased again before returning;
//            x itself is never inserted in this case.
// Requires con_son() to have filled son[] beforehand.
// x itself is never queried. Because a_i >= 1, a pair (x, descendant of x)
// would need the descendant's value to be 0, so it can never qualify.
void dfs(int x,int fa,int save)
{
    // 1. Solve the light children first; each one clears its nodes on return.
    for(const int to : edge[x])
    {
        if(to==fa||to==son[x])
            continue;
        dfs(to,x,0);
    }

    // 2. Solve the heavy child last and keep its subtree in the active set.
    if(son[x])
        dfs(son[x],x,1);

    // 3. Merge each light subtree. First query all of its nodes against the
    //    current set (those pairs have lca x), then insert them.
    for(const int to : edge[x])
    {
        if(to==fa||to==son[x])
            continue;
        tmp[0] = 0;
        conclute(to,x,1,val[x]);
        for(int j = 1; j<=tmp[0]; j++)
            insert_into_set(tmp[j]);
    }

    // 4. Either discard the children's subtrees or add x itself.
    if(save==0)
    {
        for(const int to : edge[x])
        {
            if(to!=fa)
                conclute(to,x,-1,val[x]);
        }
    }
    else if(vis[x]==0)
    {
        insert_into_set(x);
    }
}
// Reads n, the values a_1..a_n and n-1 tree edges from stdin. Node 1 is the
// root. Prints the sum of (i ^ j) over all pairs with a_i ^ a_j == a_lca(i,j).
// Exits with status 1 on malformed input instead of running on
// uninitialized values.
int main()
{
    init();
    int n;
    if(scanf("%d",&n)!=1)
        return 1;
    for(int i = 1; i<=n; i++)
    {
        if(scanf("%d",&val[i])!=1)
            return 1;
    }
    for(int i = 0; i<n-1; i++)
    {
        int a, b;
        if(scanf("%d %d",&a,&b)!=2)
            return 1;
        edge[a].push_back(b);
        edge[b].push_back(a);
    }
    con_son(1,0);
    dfs(1,0,1);
    printf("%lld\n",ans);
    return 0;
}