参考http://www.cnblogs.com/zhangchengc919/p/6042601.html
涉及高维前缀和+
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn=5e4+5;
int n,k,fk,a[maxn],p[13];
int siz[maxn],son[maxn],vis[maxn],root,MAX,cnt,num[1030],d[maxn],cn=0,head[maxn];
LL ans=0;
struct Edge
{
int to,next;
}edge[2*maxn];
inline void add_edge(int from,int to)
{
edge[cn].to=to;
edge[cn].next=head[from];
head[from]=cn++;
}
inline void init()
{
cn=0;
fk=(1<<k)-1;
ans=0;
for(int i=0;i<=n;i++)vis[i]=0;
memset(head,-1,sizeof(head));
}
void get_siz(int cur,int fa)
{
siz[cur]=1;
son[cur]=0;
for(int i=head[cur];i!=-1;i=edge[i].next)
{
int x=edge[i].to;
if(x==fa||vis[x])continue;
get_siz(x,cur);
siz[cur]+=siz[x];
if(siz[x]>son[cur])son[cur]=siz[x];
}
}
void find_root(int cur,int fa,int rt)
{
if(siz[rt]-siz[cur]>son[cur])son[cur]=siz[rt]-siz[cur];
if(son[cur]<MAX)MAX=son[cur],root=cur;
for(int i=head[cur];i!=-1;i=edge[i].next)
{
int x=edge[i].to;
if(x==fa||vis[x])continue;
find_root(x,cur,rt);
}
}
void get_state(int cur,int fa,int sta)
{
d[cnt++]=sta;
num[sta]++;
for(int i=head[cur];i!=-1;i=edge[i].next)
{
int x=edge[i].to;
if(vis[x]||x==fa)continue;
get_state(x,cur,a[x]|sta);
}
}
LL cal(int cur,int sta)
{
cnt=0;
memset(num,0,sizeof(num));
get_state(cur,0,sta);
for(int i=0;i<k;i++)
{
for(int j=fk;j>=0;j--)
{
if(!(p[i]&j))num[j]+=num[j|p[i]];
}
}
LL ret=0;
for(int i=0;i<cnt;i++)ret=ret+num[fk^d[i]];
return ret;
}
void dfs(int cur)
{
MAX=n;
get_siz(cur,0);
find_root(cur,0,cur);
int Root=root;
ans=ans+cal(root,a[Root]);
vis[root]=1;
for(int j=head[Root];j!=-1;j=edge[j].next)
{
int x=edge[j].to;
if(vis[x])continue;
ans=ans-cal(x,(a[x]|a[Root]));
dfs(x);
}
}
int main()
{
p[0]=1;for(int i=1;i<=11;i++)p[i]=p[i-1]*2;
while(scanf("%d%d",&n,&k)!=EOF)
{
init();
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
a[i]=p[a[i]-1];
}
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
dfs(1);
printf("%lld\n",ans);
}
return 0;
}
自己的不知道什么地方错了的代码:
#include <bits/stdc++.h>
using namespace std;
#define N 50050
#define LL long long
int p[15];
int a[N];
int n,k;
int d[N];
int fst[N],nxt[N<<1],vv[N<<1],e;
int vis[N];
int sz[N],son[N],mx,fk,root,num[N];
int cnt;
LL ans;
void init(){
memset(fst,-1,sizeof(fst));e=0;
ans=0;
fk=(1<<n)-1;
cnt=0;
}
void add(int u,int v){
nxt[e]=fst[u];vv[e]=v;fst[u]=e++;
}
void get_root(int u,int fa,int rt){
if(sz[rt]-sz[u]>son[u])son[u]=sz[rt]-sz[u];
if(son[u]<mx)mx=son[u],root=u;
for(int i=fst[u];~i;i=nxt[i]){
int v=vv[i];
if(v==fa||vis[v])continue;
get_root(v,u,rt);
}
}
void get_sz(int u,int fa){
sz[u]=1;
son[u]=0;
for(int i=fst[u];~i;i=nxt[i]){
int v=vv[i];
if(v==fa||vis[v])continue;
get_sz(v,u);
sz[u]+=sz[v];
if(sz[v]>son[u])son[u]=sz[v];
}
}
void get_sta(int cur,int fa,int sta){
d[cnt++]=sta;
num[sta]++;
for(int i=fst[cur];~i;i=nxt[i]){
int v=vv[i];
if(vis[v]||v==fa)continue;
get_sta(v,cur,sta|a[v]);
}
}
LL cal(int u,int sta){
cnt=0;
memset(num,0,sizeof(num));
get_sta(u,u,sta);
for(int i=0;i<k;++i){
for(int j=fk;j>=0;--j){
if(!((1<<i)&j))num[j]+=num[j|(1<<i)];
}
}
LL ret=0;
for(int i=0;i<cnt;++i)ret+=num[fk^d[i]];
return ret;
}
void dfs(int u,int fa){
mx=n;
get_sz(u,u);
get_root(u,u,a[u]);
ans+=cal(u,a[u]);
vis[root]=1;
//cout<<ans<<endl;
cout<<root<<endl;
for(int i=fst[u];~i;i=nxt[i]){
int v=vv[i];
if(v==fa||vis[v])continue;
ans-=cal(v,a[u]|a[v]);
dfs(v,u);
}
}
int main(){
//freopen("in.txt","r",stdin);
p[0]=1;for(int i=1;i<=11;++i)p[i]=p[i-1]<<1;
while(~scanf("%d%d",&n,&k)){
init();
for(int i=1;i<=n;++i)scanf("%d",&a[i]),a[i]=p[a[i]-1];
for(int i=1;i<n;++i){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
dfs(1,1);
printf("%lld\n",ans);
}
}