题目地址
题意:求一棵树上子树异或和为[0,m)的个数各位多少。
思路:f[i][j]表示以i为根异或和为j的连通子树的个数,转移就是
f[i][j^k] = f[i][j^k]+f[i][j]*f[son][k].
这是个异或卷积,用fwt优化,复杂度o(nmlogm)。
这复杂度居然能过我也是没想到的
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <stack>
#include <time.h>
#include <map>
#include <algorithm>
#include <fstream>
//#include <unordered_map>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 1500 + 100;
const int INF = 0x7fffffff;
const ll mod = 1e9+7;
const ll mod1 = 998244353;
const ll base = 137;
const double Pi = acos(-1.0);
const int G = 3;
int v[maxn];
vector<int>edge[maxn];
int dp[maxn][maxn];
int f[maxn];
int g[maxn];
int n,m;
void FWTor(int *a,int type)
{
int i,j,k;
for(i=1;i<=n;i++)
for(j=0;j<(1<<n);j+=1<<i)
for(k=0;k<(1<<i-1);k++)
(a[j|(1<<i-1)|k]+=(a[j|k]*type+mod)%mod)%=mod;
}
void FWTand(int *a,int type)
{
int i,j,k;
for(i=1;i<=n;i++)
for(j=0;j<(1<<n);j+=1<<i)
for(k=0;k<(1<<i-1);k++)
(a[j|k]+=(a[j|(1<<i-1)|k]*type+mod)%mod)%=mod;
}
void FWTxor(int *a,long long type)
{
int i,j,k,x,y;
for(i=1;i<=n;i++)
for(j=0;j<(1<<n);j+=1<<i)
for(k=0;k<(1<<i-1);k++)
x=(a[j|k]+a[j|(1<<i-1)|k])*type%mod,
y=(a[j|k]-a[j|(1<<i-1)|k]+mod)*type%mod,
a[j|k]=x,a[j|(1<<i-1)|k]=y;
}
void Or(int *ans,int *a,int *b)
{
FWTor(a,1);
FWTor(b,1);
for(int i=0;i<(1<<n);i++) a[i]=1ll*a[i]*b[i]%mod;
FWTor(a,-1);
for(int i=0;i<(1<<n);i++) ans[i]=a[i];
}
void And(int *ans,int *a,int *b)
{
FWTand(a,1);
FWTand(b,1);
for(int i=0;i<(1<<n);i++) a[i]=1ll*a[i]*b[i]%mod;
FWTand(a,-1);
for(int i=0;i<(1<<n);i++) ans[i]=a[i];
}
void Xor(int *ans,int *a,int *b)
{
FWTxor(a,1);
FWTxor(b,1);
for(int i=0;i<(1<<n);i++) a[i]=1ll*a[i]*b[i]%mod;
FWTxor(a,(mod+1)>>1);
for(int i=0;i<(1<<n);i++) ans[i]=a[i];
}
int ans[maxn];
void dfs(int x,int pre)
{
dp[x][v[x]]=1;
n=(int)log2(m);
for(auto i:edge[x])
{
if(i==pre) continue;
dfs(i,x);
for(int j=0;j<(1<<n);j++) f[i]=g[i]=ans[i]=0;
for(int j=0;j<m;j++)
{
f[j]=dp[i][j];
g[j]=dp[x][j];
}
Xor(ans,f,g);
for(int j=0;j<m;j++)
{
dp[x][j]=(1ll*dp[x][j]+ans[j])%mod;
}
}
}
int sum[maxn];
int main()
{
int t;
cin>>t;
while(t--)
{
memset(dp,0,sizeof(dp));
int n;
scanf("%d%d",&n,&m);
for(int i=0;i<m;i++)sum[i]=0;
for(int i=1;i<=n;i++)
{
edge[i].clear();
scanf("%d",&v[i]);
}
for(int i=1;i<n;i++)
{
int l,r;
scanf("%d%d",&l,&r);
edge[l].push_back(r);
edge[r].push_back(l);
}
dfs(1,0);
for(int i=1;i<=n;i++)
{
for(int j=0;j<m;j++)
{
sum[j]=(1ll*sum[j]+dp[i][j])%mod;
}
}
for(int j=0;j<m;j++)
{
printf("%d%c",sum[j],j==m-1?'\n':' ');
}
}
//system("pause");
}