题目:点击打开链接
给你一棵n个节点的树,一共有k种颜色,每个节点有一种颜色,求a到b的路径中经过了所有颜色的路径的对数,(a,b)和(b,a)都要算
思路:树分治+状态压缩+枚举子集
代码:
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<string>
#include<vector>
#include<map>
#include<set>
#include<queue>
#include<stack>
#include<list>
#include<numeric>
using namespace std;
#define PI acos(-1.0)
#define LL long long
#define ULL unsigned long long
#define INF 0x3f3f3f3f
#define mm(a,b) memset(a,b,sizeof(a))
#define PP puts("*********************");
template<class T> T f_abs(T a){ return a > 0 ? a : -a; }
template<class T> T gcd(T a, T b){ return b ? gcd(b, a%b) : a; }
template<class T> T lcm(T a,T b){return a/gcd(a,b)*b;}
// 0x3f3f3f3f3f3f3f3f
// 0x3f3f3f3f
const int maxn=5e4+50;
struct Node{
int v,next;
}edge[2*maxn];
int head[maxn],tol;//节点从1开始计数
int son[maxn],f[maxn],vis[maxn];
//son表示以i为根的子树大小
//f数组表示以u为根的最大子树的大小
//vis表示该节点是不是已经作为重心了
int d[maxn],arr[maxn];
int cntv,root,K,ALL;
LL ans,num[1050],sum[1050];
vector<int> g[1050];
void init(){
mm(head,-1);tol=0;
mm(vis,0);
f[0]=INF;//一定要初始化成一个很大的值
}
void addedge(int u,int v){
edge[tol].v=v;
edge[tol].next=head[u];
head[u]=tol++;
}
void getroot(int u,int fa){//寻找重心
son[u]=1,f[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa||vis[v]) continue;
getroot(v,u);
son[u]+=son[v];
f[u]=max(f[u],son[v]);
}
f[u]=max(f[u],cntv-son[u]);//sum表示当前树的大小
if(f[u]<f[root]) root=u;//更新当前重心
}
void getdepth(int u,int fa){
num[d[u]]++;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa||vis[v]) continue;
d[v]=(d[u]|arr[v]);
getdepth(v,u);
}
}
void cal(int u,int w){
mm(num,0);
d[u]=w;
getdepth(u,0);
}
void solve(int u){//计算以u为重心的树
vis[u]=1;
mm(sum,0);
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(vis[v]) continue;
cal(v,(arr[u]|arr[v]));
ans+=num[ALL];
for(int i=0;i<=ALL;i++){
if(num[i]==0)
continue;
for(int j=0;j<g[i].size();j++){
int v=g[i][j];
ans+=num[i]*sum[ALL&(~v)];
}
}
for(int i=0;i<=ALL;i++)
sum[i]+=num[i];
}
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(vis[v]) continue;
cntv=son[v];
root=0;
getroot(v,0);
solve(root);
}
}
int main(){
// freopen("D:\\input.txt","r",stdin);
// freopen("D:\\output.txt","w",stdout);
int n,u,v;
while(~scanf("%d%d",&n,&K)){
ALL=(1<<K)-1;
for(int i=0;i<=ALL;i++)
g[i].clear();
for(int i=0;i<=ALL;i++)
for(int j=0;j<=ALL;j++)
if((i&j)==j)
g[i].push_back(j);
init();
for(int i=1;i<=n;i++){
scanf("%d",&arr[i]);arr[i]--;
arr[i]=(1<<(arr[i]));
}
for(int i=1;i<n;i++){
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
if(K==1){
printf("%lld\n",(LL)n*n);
continue;
}
cntv=n;
root=0;//初始化根
getroot(1,0);
ans=0;
solve(root);
printf("%lld\n",ans*2);
}
return 0;
}