Description
小A有一棵N个点的树,每个点都有一个小于2^20的非负整数权值。现在小A从树中随机选择一个点x,再随机选择一个点y(x、y可以是同一个点),并对从x到y的路径上所有的点的权值分别做and、or、xor运算,最终会求得三个整数。小A想知道,他求出的三个数的期望值分别是多少。
Data Constraint
对于20%的数据,1<=N<=1000;
对于另外20%的数据,N个点构成一条链;
对于全部的数据,1<=N<=100000,1<=T<=5。
Solution
看到二进制操作就要想到按位算贡献了
这里枚举二进制上的每一位,分别当成一棵只有0/1权值节点的树处理,转化成求树上节点权值and/or/xor起来等于1的路径条数,树形dp搞即可
一开始错是搞混了每个dp数组代表什么,再错就是多组数据没有清边集数组了。想扇自己
Code
#include <stdio.h>
#include <string.h>
#include <stack>
#include <queue>
#define rep(i,st,ed) for (int i=st;i<=ed;++i)
#define drp(i,st,ed) for (int i=st;i>=ed;--i)
#define fill(x,t) memset(x,t,sizeof(x))
#define max(x,y) ((x)>(y)?(x):(y))
typedef double db;
const int N=150005;
const int E=300005;
struct edge{int y,next;}e[E];
long long a[N][2],b[N][2],c[N][2],f[N][2],g[N][2],h[N][2],t[N],p[N][21],n;
int ls[N],edCnt=0;
int vis[N];
long double ans1=0,ans2=0,ans3=0;
int read() {
int x=0,v=1; char ch=getchar();
for (;ch<'0'||ch>'9';v=(ch=='-')?(-1):(v),ch=getchar());
for (;ch<='9'&&ch>='0';x=x*10+ch-'0',ch=getchar());
return x*v;
}
void addEdge(int x,int y) {
e[++edCnt]=(edge){y,ls[x]}; ls[x]=edCnt;
e[++edCnt]=(edge){x,ls[y]}; ls[y]=edCnt;
}
/*
ans1 f[x][0/1] and 0/1 a[x][0/1] tot_f
ans2 g[x][0/1] or 0/1 b[x][0/1] tot_g
ans3 h[x][0/1] xor 0/1 c[x][0/1] tot_h*/
int get(int x) {return (x)?(1):(0);}
void dp(int now,int lim) {
vis[now]++;
f[now][0]=f[now][1]=0;
g[now][0]=g[now][1]=0;
h[now][0]=h[now][1]=0;
a[now][0]=a[now][1]=0;
b[now][0]=b[now][1]=0;
c[now][0]=c[now][1]=0;
f[now][p[now][lim]]=g[now][p[now][lim]]=h[now][p[now][lim]]=1;
for (int i=ls[now];i;i=e[i].next) {
if (vis[e[i].y]!=lim) continue;
dp(e[i].y,lim);
if (!p[now][lim]) {
f[now][0]+=f[e[i].y][0]+f[e[i].y][1];
g[now][0]+=g[e[i].y][0];
g[now][1]+=g[e[i].y][1];
h[now][0]+=h[e[i].y][0];
h[now][1]+=h[e[i].y][1];
ans2+=b[now][1]*g[e[i].y][0]+b[now][0]*g[e[i].y][1]+b[now][1]*g[e[i].y][1];
ans3+=c[now][0]*h[e[i].y][1]+c[now][1]*h[e[i].y][0];
} else {
f[now][0]+=f[e[i].y][0];
f[now][1]+=f[e[i].y][1];
g[now][1]+=g[e[i].y][0]+g[e[i].y][1];
h[now][0]+=h[e[i].y][1];
h[now][1]+=h[e[i].y][0];
ans1+=f[e[i].y][1]*a[now][1];
ans2+=(g[e[i].y][0]+g[e[i].y][1])*(b[now][0]+b[now][1]);
ans3+=h[e[i].y][0]*c[now][0]+h[e[i].y][1]*c[now][1];
}
a[now][0]+=f[e[i].y][0]; a[now][1]+=f[e[i].y][1];
b[now][0]+=g[e[i].y][0]; b[now][1]+=g[e[i].y][1];
c[now][0]+=h[e[i].y][0]; c[now][1]+=h[e[i].y][1];
}
if (!p[now][lim]) {
ans2+=b[now][1];
ans3+=c[now][1];
} else {
ans1+=a[now][1]+0.5;
ans2+=b[now][0]+b[now][1]+0.5;
ans3+=c[now][0]+0.5;
}
}
void init() {
fill(p,0);
fill(ls,0);
edCnt=0;
n=read();
rep(i,1,n) {
t[i]=read(); vis[i]=1;
int tmp=t[i],j=1;
while (tmp) {
p[i][j++]=tmp&1;
tmp/=2;
}
}
rep(i,2,n) {
int x=read();
int y=read();
addEdge(x,y);
}
}
int main(void) {
int T=read();
while (T--) {
init();
double prt1=0,prt2=0,prt3=0;
for (int i=1;i<=20;i++) {
ans1=ans2=ans3=0;
dp(1,i);
prt1+=ans1*(1<<i-1)*2/(db)(n*n);
prt2+=ans2*(1<<i-1)*2/(db)(n*n);
prt3+=ans3*(1<<i-1)*2/(db)(n*n);
}
printf("%.3lf %.3lf %.3lf\n",prt1,prt2,prt3);
}
return 0;
}