统计路径条数,想到点分治。问你包含所有颜色,想到状压。于是思路直接就出来了。 算了一下,时间复杂度大概为
O
(
n
log
n
×
2
10
)
O(n\log n\times 2^{10})
O(nlogn×210),限时
5000
M
s
5000Ms
5000Ms 貌似可以直接冲。
某个树找到重心后,我们按一个一个子树去遍历。 其中每个节点的
d
e
p
[
i
]
dep[i]
dep[i] 表示 该节点到根节点的路径上,颜色的和。 那么统计答案,我们就是要去找到之前的桶
b
i
g
[
S
]
big[S]
big[S],满足
S
∣
T
=
(
1
<
<
k
)
−
1
S|T=(1<<k)-1
S∣T=(1<<k)−1 答案就是这样去累加
∑
b
i
g
[
S
l
e
g
a
l
]
\sum big[S_{\ legal}]
∑big[Slegal]
我们这里只需要一个小特判,就不需要去容斥做了。 对于上述的计算,我们算出来的点对是子树间的两两点对。 然而
(
u
,
v
)
(u,v)
(u,v) 和
(
v
,
u
)
(v,u)
(v,u) 应该算两种,所以我们这些对之间都要乘以2。 我们唯一没有去算的就是单拿根节点的这种方案,明显只有
k
=
1
k=1
k=1 的时候,需要多累加
n
n
n 的方案
代码
时间复杂度:
O
(
n
log
n
×
2
10
)
O(n\log n\times 2^{10})
O(nlogn×210)
T
i
m
e
(
M
s
)
:
639
/
5000
Time(Ms):639/5000
Time(Ms):639/5000
/*
_ __ __ _ _
| | \ \ / / | | (_)
| |__ _ _ \ V /__ _ _ __ | | ___ _
| '_ \| | | | \ // _` | '_ \| | / _ \ |
| |_) | |_| | | | (_| | | | | |___| __/ |
|_.__/ \__, | \_/\__,_|_| |_\_____/\___|_|
__/ |
|___/
*/constint MAX =5e4+50;int n,k,sum,root;
ll ans;
vector<int>V[MAX],SS[MAX];int sz[MAX],ff[MAX];
bool vis[MAX];voidfind_rt(int x,int f){/// 点分治求重心
sz[x]=1;ff[x]=0;for(auto it : V[x]){if(it == f || vis[it])continue;find_rt(it,x);
sz[x]+= sz[it];
ff[x]=max(ff[x],sz[it]);}
ff[x]=max(ff[x],sum - sz[x]);if(ff[x]< ff[root])root = x;}int sml[MAX],big[1300],tmp[MAX],col[MAX],dep[MAX];voidget_col(int x,int f){/// 求颜色 dep[i]
sml[++sml[0]]= dep[x];for(auto it : V[x]){if(it == f || vis[it])continue;
dep[it]= dep[x]|(1<<col[it]);get_col(it,x);}}voidcal(int x){
big[0]=1;
tmp[0]=0;for(auto it : V[x]){if(vis[it])continue;
sml[0]=0;
dep[it]=(1<<col[it])|(1<<col[x]);/// 从根节点x出发get_col(it,x);for(int i =1;i <= sml[0];++i){for(auto j : SS[sml[i]]){
ans += big[j];/// 只算合法的答案}}for(int i =1;i <= sml[0];++i){
tmp[++tmp[0]]= sml[i];
big[sml[i]]++;}}for(int i =1;i <= tmp[0];++i)
big[tmp[i]]--;}voidpdc(int x){
vis[x]=1;cal(x);/// 只用算从根除法的贡献即可for(auto it : V[x]){if(vis[it])continue;
sum = sz[it];
root =0;find_rt(it,it);pdc(root);}}intmain(){while(~scanf("%d%d",&n,&k)){for(int i =0;i <(1<<k);++i)SS[i].clear();for(int i =0;i <(1<<k);++i){for(int j =0;j <(1<<k);++j){if((i | j)==(1<<k)-1){
SS[i].push_back(j);/// 事先枚举合法的集合}}}
ans =0;for(int i =1;i <= n;++i){scanf("%d",&col[i]);col[i]--;
V[i].clear();
vis[i]= false;}for(int i =1;i < n;++i){int ta,tb;scanf("%d%d",&ta,&tb);
V[ta].push_back(tb);
V[tb].push_back(ta);}
sum = n;
root=0;ff[0]= n;find_rt(1,1);pdc(root);
ans = ans *2;/// 点对答案乘以2if(k ==1)ans = ans + n;/// 简答特判即可回避容斥printf("%lld\n",ans);}return0;}