problem
有一颗 n n n 个点的无根树 T T T,每个点有权值 v i v_i vi。定义一颗非空树的价值为它子树内的点权异或和。
请对于 ∀ k ∈ [ 0 , m ) \forall k\in[0,m) ∀k∈[0,m),计算 T T T 中价值为 k k k 的非空子树的数量。其中 T T T 的子树是 T T T 的一个子图,也是一颗树。
保证 m m m 可以表示为 2 k ( k ∈ N ) 2^k(k\in N) 2k(k∈N)。
数据范围: n ≤ 1000 n≤1000 n≤1000, 1 ≤ m ≤ 2 10 1≤m≤2^{10} 1≤m≤210, 0 ≤ v i < m 0\le v_i<m 0≤vi<m。
solution
我们设 f i , j f_{i,j} fi,j 表示在 i i i 的子树内价值为 j j j 的子树的个数。
有一个暴力的转移就是枚举儿子的第二维,再枚举自己的第二维,转移到异或值的那个位置。时间复杂度是 O ( n m 2 ) O(nm^2) O(nm2) 的。
大致的代码长这样( l i m lim lim 就是 m m m):
void dfs(int x,int fa){
f[x][val[x]]=1;
for(int i=first[x];i;i=nxt[i]){
int to=v[i];
if(to==fa) continue;
dfs(to,x);
for(int j=0;j<lim;++j)
for(int k=0;k<lim;++k)
temp[j^k]=add(temp[j^k],mul(f[x][j],f[to][k]));
for(int j=0;j<lim;++j)
f[x][j]=add(f[x][j],temp[j]),temp[j]=0;
}
for(int i=0;i<lim;++i) ans[i]=add(ans[i],f[x][i]);
}
仔细盯一下那个转移,发现就是异或卷积,那么可以用 FWT 优化。
时间复杂度 O ( n m log m ) O(nm\log m) O(nmlogm)。
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 2005
#define P 1000000007
using namespace std;
int n,lim;
int val[N],f[N][N],temp[N],ans[N];
int t,first[N],v[N],nxt[N];
int add(int x,int y) {return x+y>=P?x+y-P:x+y;}
int dec(int x,int y) {return x-y< 0?x-y+P:x-y;}
int mul(int x,int y) {return 1ll*x*y%P;}
int power(int a,int b,int ans=1){
for(;b;b>>=1,a=mul(a,a))
if(b&1) ans=mul(ans,a);
return ans;
}
void edge(int x,int y){
nxt[++t]=first[x],first[x]=t,v[t]=y;
}
void Clear(){
t=0;
memset(f,0,sizeof(f));
memset(ans,0,sizeof(ans));
memset(first,0,sizeof(first));
}
void FWT(int *f,int type){
for(int mid=1;mid<lim;mid<<=1){
for(int i=0;i<lim;i+=(mid<<1)){
for(int j=0;j<mid;++j){
int p0=f[i+j],p1=f[i+j+mid];
f[i+j]=add(p0,p1),f[i+j+mid]=dec(p0,p1);
}
}
}
if(type==-1){
int inv=power(lim,P-2);
for(int i=0;i<lim;++i) f[i]=mul(f[i],inv);
}
}
void dfs(int x,int fa){
f[x][val[x]]=1;
for(int i=first[x];i;i=nxt[i]){
int to=v[i];
if(to==fa) continue;
dfs(to,x);
for(int j=0;j<lim;++j) temp[j]=f[x][j];
FWT(temp,1),FWT(f[to],1);
for(int j=0;j<lim;++j) temp[j]=mul(temp[j],f[to][j]);
FWT(temp,-1);
for(int j=0;j<lim;++j) f[x][j]=add(f[x][j],temp[j]);
}
for(int i=0;i<lim;++i) ans[i]=add(ans[i],f[x][i]);
}
int main(){
int T,x,y;
scanf("%d",&T);
while(T--){
Clear();
scanf("%d%d",&n,&lim);
for(int i=1;i<=n;++i) scanf("%d",&val[i]);
for(int i=1;i<n;++i){
scanf("%d%d",&x,&y);
edge(x,y),edge(y,x);
}
dfs(1,0);
for(int i=0;i<lim;++i) printf("%d%c",ans[i],i==lim-1?'\n':' ');
}
return 0;
}