题意:给出一棵树(n<=5e4),每个点有一个颜色,颜色数k<=10(现场赛是7,可以直接状压暴力DP水过去)。现在要你统计合法路径的个数,一个合法路径可以用一个有序二元组(x,y)表示,代表从x到y的简单路径(xy可以相同),且这条路径上的点cover掉了所有的颜色。(1,2)和(2,1)是两条路径 (1,1)也是一条路径。
题解:先讲讲7的那个退化版的做法:7个颜色可以用7位压起来表示状态,2^7-1=127,因此可以用一个空间复杂度O(n*2^k)时间复杂度为O(n*2^k)的树形DP或者说dfs搞定,属于简单题。
再来看看k=10的版本,由于k=10,n*2^k的时间或者空间都无法接受。本题属于路径统计问题,是比较明显的点分治模型,那么在点分治的时候,我们可以从当前选定的根开始dfs出所有点的状态值,然后开始组合答案。考虑某个点的状态值是x,假如二进制形式是00110,而一个合法路径的状态值应该是11111,于是另一半路径的状态值必须保证是11XX1这样的形式,也就是说x中0的位另一半路径必须是1才行。稍微思考一下,可以想到一个枚举子集的法子,方法是,我们目标状态是11111,我们枚举它的子集,作为第一段路径的status,设为X,然后枚举X的子集,设为X*,意思是,让X*中的1作为合法路经中的1,然后显然另一半路径status必须是11111^X*,然后把X*和11111^X*的cnt乘起来就得到这部分答案。
然后是子树去重,假设根节点颜色是c,比如是2,那么我们在一个子树里面统计的非法路径要去掉两种,一种是在子树中即达到11111的路径,另一种是在子树中达到11101的路径,这个路径在上面的统计中由于重复走了c,就形成了11111,也被统计过了。求法和上面相同,只是目标值有所变化,方法依然是二维枚举子集。
关于起点终点相同的路径:在第一次统计的时候,比如dfs出某个点Z的status就是11111了,那么会把(Z,Z)这个路径也统计进去,但是我们并没有特意把他去掉,因为这个也属于起点终点分布在同一个子树中的情况,因此我们在子树中也会统计到它,这样一减,就没问题了。
Code:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e4+100;
const int maxk = 12;
int first[maxn],nxt[maxn*2],des[maxn*2],tot;
int a[maxn];
int bas[maxk];
int status[maxn];
bool vis[maxn];
long long cnt[1100];
int sz[maxn],ssz[maxn];
int k,n,maxstatus;
long long ans;
int Min,Minid;
const int INF = 0x3f3f3f3f;
void prework(){
bas[0]=1;
for (int i=1;i<=10;i++){
bas[i] =bas[i-1]*2;
}
}
void init(){
tot=ans=0;
memset(vis,0,sizeof vis);
memset(first,0,sizeof first);
}
inline void addEdge(int x_,int y_){
tot++;
des[tot] =y_;
nxt[tot]=first[x_];
first[x_]=tot;
}
void input(){
for (int i=1;i<=n;i++){
int flag;
scanf("%d",&flag);
a[i] =bas[flag-1];
}
for (int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
addEdge(u,v);addEdge(v,u);
}
maxstatus = (1<<k)-1;
}
void getSize(int node,int father){
sz[node]=1;
ssz[node]=0;
for (int t = first[node];t;t=nxt[t]){
int v = des[t];
if (v==father||vis[v])continue;
getSize(v,node);
sz[node]+=sz[v];
if (sz[v]>ssz[node])ssz[node] = sz[v];
}
}
void find_root(int node,int father,int root){
int val = max(sz[root]-sz[node],ssz[node]);
if (val<Min){
Min = val;
Minid = node;
}
for (int t = first[node];t;t=nxt[t]){
int v = des[t];
if (v==father||vis[v])continue;
find_root(v,node,root);
}
}
int getRoot(int node){
getSize(node,-1);
Min =Minid = INF;
find_root(node,-1,node);
return Minid;
}
void getStatus(int node,int father){
status[node] = status[father]|a[node];
cnt[status[node]] ++;
for (int t = first[node];t;t=nxt[t]){
int v = des[t];
if (v==father||vis[v])continue;
getStatus(v,node);
}
}
long long calc(int node,int S){
memset(cnt,0,sizeof cnt);
getStatus(node,0);
long long res =0;
for (int i=S;i;i = (i-1)&S){
if (cnt[i]){
for (int x = i;x;x=(x-1)&i){
res += cnt[i]*cnt[x^S];
}
res +=cnt[i]*cnt[0^S];
}
}
return res;
}
void solve(int node){
int root = getRoot(node);
ans+=calc(root,maxstatus);
vis[root] =true;
for (int t = first[root];t;t=nxt[t]){
int v = des[t];
if (vis[v])continue;
ans -=calc(v,maxstatus);
ans-= calc(v,a[root]^maxstatus);
solve(v);
}
}
int main(){
prework();
while (scanf("%d%d",&n,&k)!=EOF){
init();
input();
if (k==1){
cout<<1LL*n+1LL*n*(n-1)<<endl;
continue;
}
solve(1);
cout<<ans<<endl;
}
return 0;
}