题目大意:给出一棵 n 个点组成的有根树,一号节点是根节点,现在要求实现 n * n 的公式:
题目分析:树上启发式合并,需要修改部分内部实现,如果可以想到树启的话,那么应该往子树上去靠拢,当每个点作为子树的根时,其可以作为 lca 然后去统计子树中可以匹配的 ( u , v ) 点对,这个题目因为 a[ i ] != 0,换句话说,点 u , v , lca( u , v ) 一定是互不相同的三个点,极大程度上简化了题目(因为三个点的形式一定是一个分叉的形状,不可能是链状的),换句话说,当 lca 确定后,可以枚举每个点作为点 u,然后去统计“除了点 u 到点 lca 的这条路径上的点之外,以 lca 为根节点的子树中,所有 a[ v ] = a[ u ] ^ a[ lca ] 的点 v,然后对 v^u 加和”,但如果每个点都去枚举的话,那么时间复杂度将会是 n * n 级别的,又因为如果每个点都枚举一遍的话,换句话说对于一个有贡献的 ( u , v ) 点对来说,点 u 和点 v 都会被枚举一遍从而形成了浪费
再考虑树启是关于轻重链剖分的,然后访问轻链的次数是 O( logn ) 级别的,因此可以只枚举轻链上的点的贡献,然后去寻找满足条件的点与其匹配即可,这样时间复杂度就下降到了 nlogn,具体实现就是,对于某一条轻链来说,先维护其贡献,然后再将其下标更新到数据结构中维护
下面讲该如何实现,首先上面简化的模型已经足够可以完成公式,但本题实际上是要求记录下标的异或和,比赛时没有多想,直接对每个权值维护了一个 set,用于记录当前权值下有多少个下标,维护答案的时候暴力枚举即可,本以为时间复杂度是 nlog^2n 的,TLE 的原因是因为 set 的那一层 log,于是想到每次增加和删除都是连续的一段,将 set 换成了 vector 便轻松将时间复杂度控制到了 nlogn 级别,交上去也真的 AC 了,赛后和潘学长还有zx学长讨论过后才发现,自己实际上写了个假算法。。因为如果暴力去维护下标的话,时间复杂度其实是 O( k * n * logn ),这里的 k 是,整棵树中出现次数最多的数字的出现次数,如果整棵树都是同一个数字的话,那么时间复杂度将退化成 O( n * n * logn ) 级别的。。然鹅感谢出题人,卑微菜鸡在这里给你磕头了,咚咚咚,多谢出题人没有刻意去卡这种数据,导致在随机数据下,常数 k 好像很小很小,甚至比 std 跑的还快。。
然后讲一下正解吧,因为我们需要维护的是下标的异或和,而对于异或而言,非常重要的一个性质就是,拆位之后每一位都相互独立,所以我们不妨对于下标的每一位都单独跑一次树启,最后将答案加和即可,对于某一位来说,假设 u == 1,对于匹配到的 v 来说,只有 v == 0 的位置才具有贡献,对于 u == 0 而言,同理只有 v == 1 才有贡献,所以类比于上一段的思路,对于每个权值 val 来说,上一段的做法维护的是一个 vector,里面储存着 a[ x ] == val 的 x,也就是下标,而本段的思路是,对于每个权值 val 来说,记录一下所有下标,在二进制下第 i 位中共出现了多少个 0 和多少个 1,只是换了一下实现思路而已,这样实现的时间复杂度是严格 O( nlogn * 20 ) 的,20 的意思是需要将下标拆成 20 位,因为 2^20 > 1e6 >= a[ i ]
代码:
//#pragma GCC optimize(2)
//#pragma GCC optimize("Ofast","inline","-ffast-math")
//#pragma GCC target("avx,sse2,sse3,sse4,mmx")
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
#include<bitset>
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int N=1e5+100;
vector<int>node[N],temp;
bool vis[N];
int deep[N],num[N],son[N],a[N];
int cnt[(1<<20)+100][25][2];//拆位
LL ans;
void update(int id,int val)
{
for(int i=0;i<20;i++)
cnt[a[id]][i][(id>>i)&1]+=val;
}
LL search(int id,int num)
{
LL ans=0;
for(int i=0;i<20;i++)
ans+=cnt[num][i][!((id>>i)&1)]*(1<<i);
return ans;
}
void dfs_son(int u,int fa,int dep)
{
deep[u]=dep;
son[u]=-1;
num[u]=1;
for(auto v:node[u])
{
if(v==fa)
continue;
dfs_son(v,u,dep+1);
num[u]+=num[v];
if(son[u]==-1||num[v]>num[son[u]])
son[u]=v;
}
}
void cal(int u,int fa,int lca)
{
temp.push_back(u);
ans+=search(u,a[u]^a[lca]);
for(auto v:node[u])
{
if(v==fa||vis[v])
continue;
cal(v,u,lca);
}
}
void del(int u,int fa)
{
update(u,-1);
for(auto v:node[u])
{
if(v==fa||vis[v])
continue;
del(v,u);
}
}
void dfs(int u,int fa,int keep)
{
for(auto v:node[u])
{
if(v==fa||v==son[u])
continue;
dfs(v,u,0);
}
if(son[u]!=-1)
{
dfs(son[u],u,1);
vis[son[u]]=true;
}
update(u,1);
for(auto v:node[u])
{
if(v==fa||vis[v])
continue;
cal(v,u,u);
for(auto it:temp)
update(it,1);
temp.clear();
}
if(son[u]!=-1)
vis[son[u]]=false;
if(!keep)
del(u,fa);
}
int main()
{
#ifndef ONLINE_JUDGE
// freopen("data.in.txt","r",stdin);
// freopen("data.out.txt","w",stdout);
#endif
// ios::sync_with_stdio(false);
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",a+i);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
node[u].push_back(v);
node[v].push_back(u);
}
dfs_son(1,-1,0);
dfs(1,-1,1);
printf("%lld\n",ans);
return 0;
}