http://acm.hdu.edu.cn/showproblem.PHP?pid=5909
问题描述
Byteasar有一棵nn个点的无根树,节点依次编号为11到nn,其中节点ii的权值为v_ivi。
定义一棵树的价值为它所有点的权值的异或和。
现在对于每个[0,m)[0,m)的整数kk,请统计有多少TT的非空连通子树的价值等于kk。
一棵树TT的连通子树就是它的一个连通子图,并且这个图也是一棵树。
一个暴力的做法就是 树dp
- void dfs(int x,int fa)
- {
- dp[x][a[x]]=1; //与自己异或之后的值为a【x】的方案数
- for (int i=0; i<mp[x].size(); i++)
- {
- int v=mp[x][i];
- if(v==fa) continue;
- dfs(v,x);
- for (int i=0; i<m; i++)
- tmp[i]=dp[x][i];
- solve(dp[x],dp[v],m);//其中这里求得其实是 当前dp[x]的所有值与dp[v]的所有值 异或的结果, **
- for (int i=0; i<m; i++)
- dp[x][i]=(dp[x][i]+tmp[i])%mod;
- }
- for (int i=0;i<m;i++)
- ans[i]=(ans[i]+dp[x][i])%mod;
- }
**处 的solve,暴力两个for 是n^2的,如果用fwt加速,可以做到nlogn
复杂度从n^3变为n*n*logn
恩,就是这样咯
- #include<bits/stdc++.h>
- using namespace std;
- const int N=1e3+100,mod=1e9+7,rev=(mod+1)>>1;
- int a[N],dp[N][N],ans[N];
- vector <int >mp[N];
- int n,m;
- int tmp[N];
- void FWT(int *a,int n)
- {
- for(int d=1; d<n; d<<=1)
- for(int m=d<<1,i=0; i<n; i+=m)
- for(int j=0; j<d; j++)
- {
- int x=a[i+j],y=a[i+j+d];
- a[i+j]=(x+y)%mod,a[i+j+d]=(x-y+mod)%mod;
- }
- }
-
- void UFWT(int *a,int n)
- {
- for(int d=1; d<n; d<<=1)
- for(int m=d<<1,i=0; i<n; i+=m)
- for(int j=0; j<d; j++)
- {
- int x=a[i+j],y=a[i+j+d];
- a[i+j]=1LL*(x+y)*rev%mod,a[i+j+d]=(1LL*(x-y)*rev%mod+mod)%mod;
- }
- }
-
- void solve(int *a,int *b,int n)
- {
- FWT(a,n);
- FWT(b,n);
- for(int i=0; i<n; i++) a[i]=1LL*a[i]*b[i]%mod;
- UFWT(a,n);
- }
- void dfs(int x,int fa)
- {
- dp[x][a[x]]=1;
- for (int i=0; i<mp[x].size(); i++)
- {
- int v=mp[x][i];
- if(v==fa) continue;
- dfs(v,x);
- for (int i=0; i<m; i++)
- tmp[i]=dp[x][i];
- solve(dp[x],dp[v],m);
- for (int i=0; i<m; i++)
- dp[x][i]=(dp[x][i]+tmp[i])%mod;
- }
- for (int i=0;i<m;i++)
- ans[i]=(ans[i]+dp[x][i])%mod;
- }
-
-
- int main()
- {
- int t,u,v;
- cin>>t;
- while(t--)
- {
- cin>>n>>m;
- memset(dp,0,sizeof dp);
- memset(ans,0,sizeof ans);
- for (int i=1; i<=n; i++) mp[i].clear();
- for (int i=1; i<=n; i++) scanf("%d",&a[i]);
- for (int i=1; i<n; i++)
- {
- scanf("%d%d",&u,&v);
- mp[u].push_back(v);
- mp[v].push_back(u);
- }
- dfs(1,0);
- for (int i=0;i<m;i++)
- printf("%d%c",ans[i],i==m-1?'\n':' ');
-
- }
-
-
-
- return 0;
- }