Tree Cutting
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/131072 K (Java/Others)
Total Submission(s): 1298 Accepted Submission(s): 501
Problem Description
Byteasar has a tree $T$ with $n$ vertices conveniently labeled $1, 2, \dots, n$. Each vertex $i$ of the tree has an integer value $v_i$.
The value of a non-empty tree $T$ is equal to $v_1 \oplus v_2 \oplus \cdots \oplus v_n$, where $\oplus$ denotes bitwise XOR.
Now, for every integer $k$ in $[0, m)$, please calculate the number of non-empty subtrees of $T$ whose value is equal to $k$.
A subtree of T is a subgraph of T that is also a tree.
Input
The first line of the input contains an integer T(1≤T≤10), denoting the number of test cases.
In each test case, the first line of the input contains two integers $n$ ($n \le 1000$) and $m$ ($1 \le m \le 2^{10}$), denoting the size of the tree $T$ and the upper bound of the values $v_i$.
The second line of the input contains $n$ integers $v_1, v_2, v_3, \dots, v_n$ ($0 \le v_i < m$), denoting the value of each node.
Each of the following $n-1$ lines contains two integers $a_i, b_i$ ($1 \le a_i, b_i \le n$), denoting an edge between vertices $a_i$ and $b_i$.
It is guaranteed that $m$ can be represented as $2^k$, where $k$ is a non-negative integer.
Output
For each test case, print a line with $m$ integers; the $i$-th number (counting from $i = 0$) is the number of non-empty subtrees of $T$ whose value is equal to $i$.
The answer may be huge, so print it modulo $10^9+7$.
Sample Input
2
4 4
2 0 1 3
1 2
1 3
1 4
4 4
0 1 3 1
1 2
1 3
1 4
Sample Output
3 3 2 3
2 4 2 3
题意:有一棵n个点的无根树,节点依次编号为1到n,其中节点i的权值为vi, 定义一棵树的价值为它所有点的权值的异或和。 现在对于每个[0,m)的整数k,请统计有多少T的非空连通子树的价值等于k。
思路:树形dp。dp[u][k] 表示位于u的子树内、包含u、且异或和为k的连通子图的方案数。处理u的每个子节点v时,更新为 dp[u] ← dp[u] + dp[u] ⊛ dp[v],其中 ⊛ 表示异或卷积(原有方案要么不取v的子树,要么与某个以v为根的连通子图拼接),该卷积采用fwt加速。最终答案为所有节点的 dp[u][k] 之和。
#include<iostream>
#include<stdio.h>
#include<vector>
#include<string.h>
using namespace std;

// 64-bit integer type used for all modular arithmetic.
typedef long long ll;

// Prime modulus required by the problem (10^9 + 7).
const ll mod = 1000000007;

// Capacity of the per-vertex arrays and of every DP row; it must cover both
// the largest vertex label (n <= 1000) and the largest value range (m <= 1024).
const int maxm = 1300;

// Vertex count and value range (a power of two) of the current test case.
int n, m;

// Adjacency lists, indexed by 1-based vertex label.
vector<int> g[1005];

// a[i]: value v_i of vertex i (1-based).
int a[maxm];

// NOTE(review): not referenced anywhere in this file (functions named b are
// parameters that shadow it); kept so the declaration set is unchanged.
int b[maxm];

// dp[u][k]: number of connected subgraphs of u's subtree that contain u and
// whose xor equals k (filled in by dfs).
ll dp[maxm][maxm];

// ans[k]: running sum of dp[u][k] over all vertices, i.e. the printed answer.
ll ans[maxm];

// Scratch row used by dfs when merging a vertex with one child.
ll tmp[maxm];

// Modular inverse of 2; assigned in main and consumed by UFWT.
ll rev;
// Computes a^b modulo `mod` by binary exponentiation.
//   a: any base, including negative values or values >= mod. It is first
//      reduced into [0, mod), so every product below is at most (mod-1)^2,
//      which fits in a signed 64-bit integer (no overflow).
//   b: the exponent; must be non-negative. A negative exponent is not
//      supported and simply yields 1 rather than never terminating.
// Returns a value in [0, mod).
ll quick_mod(ll a,ll b)
{
    a %= mod;
    if (a < 0) a += mod;          // built-in % keeps the sign of the dividend
    ll res = 1;                   // named res to avoid shadowing the global ans
    while (b > 0)
    {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// In-place forward Walsh-Hadamard transform for XOR convolution, modulo `mod`.
// a[0..n) is overwritten with its transform. n must be a power of two,
// otherwise the butterflies index past a[n-1]. Entries are expected to lie in
// [0, mod) so that every output also lies in [0, mod).
// Butterfly variants for other convolutions (lo = a[lo], hi = a[hi]):
//   AND: a[lo] = lo + hi        OR: a[hi] = lo + hi
void FWT(ll a[],int n)
{
    for (int half = 1; half < n; half <<= 1)
    {
        const int span = half << 1;
        for (int base = 0; base < n; base += span)
        {
            ll *lo = a + base;
            ll *hi = lo + half;
            for (int k = 0; k < half; ++k)
            {
                const ll u = lo[k];
                const ll v = hi[k];
                lo[k] = (u + v) % mod;
                hi[k] = (u - v + mod) % mod;
            }
        }
    }
}
// In-place inverse of FWT, modulo `mod`. It uses the same butterflies, but
// every output is multiplied by rev (the inverse of 2, assigned in main). After
// log2(n) levels this applies the overall 1/n factor. n must be a power of two.
// Butterfly variants for other convolutions (lo = a[lo], hi = a[hi]):
//   AND: a[lo] = lo - hi        OR: a[hi] = hi - lo
void UFWT(ll a[],int n)
{
    for (int half = 1; half < n; half <<= 1)
    {
        const int span = half << 1;
        for (int base = 0; base < n; base += span)
        {
            ll *lo = a + base;
            ll *hi = lo + half;
            for (int k = 0; k < half; ++k)
            {
                const ll u = lo[k];
                const ll v = hi[k];
                lo[k] = (u + v) * rev % mod;
                // (u - v) may be negative; the outer "+ mod" restores [0, mod)
                hi[k] = ((u - v) * rev % mod + mod) % mod;
            }
        }
    }
}
// XOR-convolves a with b modulo `mod`. On return,
//   a[k] = sum over all i ^ j == k of a[i] * b[j]   (mod `mod`).
// Both arrays have length n, which must be a power of two.
// Side effect: b is left in the transformed (FWT) domain and is NOT restored,
// so callers must not rely on its contents afterwards.
void solve(ll a[],ll b[],int n)
{
    FWT(a, n);
    FWT(b, n);
    // Pointwise product in the transform domain is convolution in the original.
    for (int k = 0; k < n; ++k)
        a[k] = a[k] * b[k] % mod;
    UFWT(a, n);
}
void dfs(int u,int p)
{
dp[u][a[u]]=1;
for(int i = 0;i<g[u].size();i++)
{
if(g[u][i]==p)continue;
dfs(g[u][i],u);
for(int j = 0;j<m;j++)
{
tmp[j] = dp[u][j];
}
solve(tmp,dp[g[u][i]],m);
for(int j = 0;j<m;j++)
{
dp[u][j] = (dp[u][j] + tmp[j])%mod;
}
}
for(int i = 0;i<m;i++)
{
ans[i] =(ans[i] + dp[u][i])%mod;
}
}
// Reads the number of test cases and, for each one, reads the tree.
// It then runs the xor-subtree DP from vertex 1 and prints the m counts
// (modulo `mod`) on one line.
// Every scanf result is checked: on truncated or malformed input the program
// stops cleanly instead of continuing with uninitialized sizes/indices.
int main()
{
    rev = quick_mod(2ll, mod - 2);      // inverse of 2 (Fermat, mod is prime), used by UFWT
    int t;
    if (scanf("%d", &t) != 1) return 0;
    while (t--)
    {
        memset(dp, 0, sizeof(dp));
        memset(ans, 0, sizeof(ans));
        for (int i = 0; i < 1005; i++) g[i].clear();
        if (scanf("%d%d", &n, &m) != 2) return 0;
        for (int i = 1; i <= n; i++)
            if (scanf("%d", &a[i]) != 1) return 0;
        for (int i = 0; i < n - 1; i++)
        {
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0;
            g[u].push_back(v);
            g[v].push_back(u);
        }
        dfs(1, 0);
        for (int i = 0; i < m; i++) printf("%lld%c", ans[i], i == m - 1 ? '\n' : ' ');
    }
    return 0;
}