题目大意:
就是给你一颗树,树上有各种权值,权值只有 K 种 k ∈ [ 1 , 10 ] K种k\in[1,10] K种k∈[1,10],问你有多少路径覆盖了这 K K K种权值, n ∈ [ 1 , 5 e 4 ] n\in[1,5e4] n∈[1,5e4],结果乘2输出,就是路径分方向。
解题思路:
因为 K K K很小,所以我们可以状态压缩去表示这 K K K个数的选择,如果这条路径上的或起来的结果为 ( 1 l l < < K ) − 1 (1ll<<K)-1 (1ll<<K)−1那么这条路径上就覆盖了 K K K个权值。
我们可以先预处理出所有两个数或的结果为
(
1
l
l
<
<
k
)
−
1
(1ll<<k)-1
(1ll<<k)−1的组合,之后暴力转移就可以了
复杂度
O
(
1023
∗
1023
∗
n
)
≈
O
(
1
e
7
)
O(1023*1023*n) \approx O(1e7)
O(1023∗1023∗n)≈O(1e7)
注意: K = = 1 K == 1 K==1的时候路径可以只有一个点特判一下
#pragma comment(linker,"/STACK:102400000,102400000")
#pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 10;
vector<int> init[maxn];
vector<int> G[maxn];
int N, K, val[maxn];
ll ans;
//..................
int root, Mx, now_node;
int max_son[maxn], siz[maxn];
bool vis[maxn];
void getroot(int u, int fa) {
siz[u] = 1;
max_son[u] = 0;
for(int i = 0; i < G[u].size(); ++ i) {
int it = G[u][i];
if(vis[it] || fa == it) continue;
getroot(it,u);
max_son[u] = max(max_son[u],siz[it]);
siz[u] = siz[it] + siz[u];
}
max_son[u] = max(max_son[u],now_node - siz[u]);
if(max_son[u] < Mx) Mx = max_son[u], root = u;
}
//...................
unordered_map<ll,ll> mp;
void dfs(int u, int fa, int sum) {
if(mp.count(sum)) mp[sum] ++;
else mp[sum] = 1;
for(int i = 0; i < G[u].size(); ++ i) {
int it = G[u][i];
if(it == fa || vis[it]) continue;
dfs(it,u,sum | (1 << (val[it]-1)));
}
}
inline ll getans(int u, int kind) {
mp.clear();
dfs(u,0,(1 << (val[u]-1))|kind);
ll res = 0;
for(auto i : mp)
for(auto j : init[i.first]) {
if(mp.count(j)) {
if(i.first == j) res += (mp[j]-1)*i.second;
else res += mp[j]*i.second;
}
}
return res;
}
void Div(int u) {
vis[u] = 1;
ans += getans(u,0);
for(int i = 0; i < G[u].size(); ++ i) {
int it = G[u][i];
if(vis[it]) continue;
ans -= getans(it,(1ll << (val[u]-1)));
Mx = 1e9, root = 0;
now_node = siz[it];
getroot(it,0);
Div(root);
}
}
void Init(int Bit) {
for(int i = 0; i <= Bit; ++ i)
for(int j = 0; j <= Bit; ++ j)
if((i | j) == Bit)
init[i].push_back(j);
}
int main() {
while(scanf("%d%d",&N,&K) != EOF) {
Init((1ll << K)-1);
ans = 0;
for(int i = 1; i <= N; ++ i) vis[i] = 0, G[i].clear();
for(int i = 1; i <= N; ++ i)
scanf("%d",&val[i]);
for(int i = 1; i < N; ++ i) {
int l, r;
scanf("%d%d",&l,&r);
G[l].push_back(r);
G[r].push_back(l);
}
now_node = N;
Mx = 1e9, root = 0;
getroot(1,0);
Div(root);
printf("%lld\n",ans+(K==1)*N);
for(int i = 0; i <= (1ll << K)-1;++i) init[i].clear();
}
return 0;
}
/*
6 3
1 2 3 3 3 2
1 2
2 4
2 5
1 3
3 6
6 3
1 2 3 3 3 2
1 2
2 4
2 5
1 3
1 6
*/