题意:给定N个点,K个种类,每个点都有自己的种类,问有多少点对可以覆盖全部的种类
看到k的范围很小,不妨利用状压的思想,将每个种类的点转换为1<<a[i],即寻找当前的和是否达到1<<k-1,如果达到,说明k种齐全。
这里有一个特殊的处理技巧,可以枚举子集的所有状态。
for (int sub = S; sub; sub = (sub - 1) & S) {
// sub 为 S 的子集
}
我们同样可以利用hash的思想,因为a | b = c ,我们可以转化为 c ^ b = a。
记得k=1时特判处理。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <stack>
#include <cmath>
#include <bitset>
#include <map>
using namespace std;
//#define ACM_LOCAL
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const int N = 1e5 + 5;
const int INF = 0x3f3f3f3f;
const int MOD = 1e6 + 3;
int n, m, cnt, h[N], rt, sz[N], mx[N], vis[N], sum, k, a[N];
int d[N], dep[N];
int path[1<<10];
ll ans;
struct edge{
int to, next;
}e[N<<1];
void add(int u, int v) {
e[cnt].to = v;
e[cnt].next = h[u];
h[u] = cnt++;
}
void getroot(int x, int fa) {
sz[x] = 1, mx[x] = 0;
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (y == fa || vis[y]) continue;
getroot(y, x);
sz[x] += sz[y];
mx[x] = max(mx[x], sz[y]);
}
mx[x] = max(mx[x], sum - sz[x]);
if (mx[x] < mx[rt]) rt = x;
}
void getd(int x, int fa, int now) {
d[++d[0]] = now;
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (y == fa || vis[y]) continue;
getd(y, x, now | (1<<a[y]));
}
}
ll cal(int x, int now) {
ll res = 0;
d[0] = 0;
memset(path, 0, sizeof path);
getd(x, -1, now);
for (int i = 1; i <= d[0]; i++) path[d[i]]++;
for (int i = 1; i <= d[0]; i++) {
path[d[i]]--;
res += path[(1<<k)-1];
for (int j = d[i]; j; j = (j-1) & d[i]) {
res += path[((1<<k)-1)^j];
}
path[d[i]]++;
}
return res;
}
void work(int x) {
vis[x] = 1;
ans += cal(x, 1<<a[x]);
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (vis[y]) continue;
ans -= cal(y, (1<<a[x]) | (1<<a[y]));
sum = sz[y], rt = 0;
getroot(y, -1);
work(rt);
}
}
void solve () {
while (~scanf("%d %d", &n, &k)) {
memset(h, -1, sizeof h);
memset(vis, 0, sizeof vis);
cnt = 0;
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), a[i]--;
for (int i = 1; i < n; i++) {
int x, y;
scanf("%d %d", &x, &y);
add(x, y);
add(y, x);
}
mx[0] = INF, rt = 0, sum = n, ans = 0;
getroot(1, -1);
work(rt);
if (k == 1) printf("%lld\n", 1ll*n*n);
else printf("%lld\n", ans);
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifdef ACM_LOCAL
freopen("input", "r", stdin);
freopen("output", "w", stdout);
#endif
solve();
}