题意:给定一棵树,树的每个点有点权,定义2个点u和v之间的距离为u到v的路径上的点的点权的异或和。求全体点对(u,v):1<=u<=v<=n的距离和。
分析:考虑按位处理距离和。设有ans[i]个点对的距离的第i位为1,则距离和=ans[0]*2^0+ans[1]*2^1+...+ans[20]*2^20。从而问题转化为点权为0或1的情况。对于转化后的问题,我是用树分治处理的:对于子树u的所有点对路径,要么经过点u,要么不经过点u,经过点u的通过维护点权异或和为0,1的链数来进行统计,不经过点u的递归处理。注意应该把点权转化为向量来处理,这样只用跑一遍树分治,跑20遍会tle。
看完题解后发现不用树分治,直接树形dp就可以了。这是因为每个点只要访问一次就能得到我们需要的信息,所以不需要每层都把所以点都遍历一遍。
代码(树分治)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long LL;
const int maxn=1e5+10,maxl=21;
int n,a[maxn],b[maxn][maxl];
vector<int> G[maxn];
int sz[maxn],root,sum,minmaxs;
LL s[maxl][2],t[maxl][2],ans[maxl];
bool done[maxn];
void dfs_sz(int u,int fu)
{
sz[u]=1;
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu||done[v]) continue;
dfs_sz(v,u);sz[u]+=sz[v];
}
}
void dfs_rt(int u,int fu)
{
int maxs=sum-sz[u];
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu||done[v]) continue;
dfs_rt(v,u);maxs=max(maxs,sz[v]);
}
if (maxs<minmaxs) {minmaxs=maxs;root=u;}
}
void dfs(int u,int fu,int *o)
{
for (int l=0;l<maxl;l++) t[l][o[l]]++;
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu||done[v]) continue;
int o1[maxl];
for (int l=0;l<maxl;l++) o1[l]=o[l]^b[v][l];
dfs(v,u,o1);
}
}
void solve(int u)
{
dfs_sz(u,-1);
minmaxs=maxn;sum=sz[u];
dfs_rt(u,-1);
u=root;done[u]=1;
//cout<<u<<endl;
memset(s,0,sizeof(s));
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (done[v]) continue;
memset(t,0,sizeof(t));
dfs(v,u,b[v]);
for (int l=0;l<maxl;l++)
{
if (b[u][l])
{
ans[l]+=t[l][0];
ans[l]+=t[l][0]*s[l][0]+t[l][1]*s[l][1];
}
else
{
ans[l]+=t[l][1];
ans[l]+=t[l][0]*s[l][1]+t[l][1]*s[l][0];
}
s[l][0]+=t[l][0];s[l][1]+=t[l][1];
}
}
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (done[v]) continue;
solve(v);
}
}
int main()
{
LL ret=0;
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]),ret+=a[i];
for (int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
G[u].push_back(v);G[v].push_back(u);
}
for (int i=1;i<=n;i++)
for (int j=0;j<maxl;j++)
b[i][j]=(a[i]>>j)&1;
solve(1);
for (int l=0;l<maxl;l++) ret+=(1<<l)*ans[l];
cout<<ret;
return 0;
}
代码(树形dp)
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn=1e5+10,maxl=21;
int n,a[maxn],f[maxn][maxl][2];
vector<int> G[maxn];
LL ans[maxl];
void dp(int u,int fu)
{
LL s[maxl][2];memset(s,0,sizeof(s));
for (int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if (v==fu) continue;
dp(v,u);
for (int l=0;l<maxl;l++)
{
if (a[u]&(1<<l))
ans[l]+=f[v][l][0]+f[v][l][0]*s[l][0]+f[v][l][1]*s[l][1];
else
ans[l]+=f[v][l][1]+f[v][l][0]*s[l][1]+f[v][l][1]*s[l][0];
//if (l==10&&u==2) cout<<f[v][l][1]<<" "<<ans[l]<<endl;
s[l][0]+=f[v][l][0];
s[l][1]+=f[v][l][1];
}
}
for (int l=0;l<maxl;l++)
{
int d=(a[u]>>l)&1;
f[u][l][0]=s[l][0^d];f[u][l][1]=s[l][1^d];f[u][l][d]++;
//if (l==10&&u==3) cout<<d<<endl;
}
}
int main()
{
LL ret=0;
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]),ret+=a[i];
for (int i=1;i<n;i++)
{
int u,v;scanf("%d%d",&u,&v);
G[u].push_back(v);G[v].push_back(u);
}
dp(1,-1);
//for (int l=0;l<maxl;l++) cout<<ans[l]<<" ";
for (int l=0;l<maxl;l++) ret+=ans[l]*(1<<l);
cout<<ret;
return 0;
}