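// Adapted from the classic tree path-counting problem in the "Eight Problems for Men" (男人八题) set.
// Weights move from edges to vertices, and each vertex carries 30 ternary (mod-3) states.
// The approach is essentially the same as the original centroid decomposition; each subproblem uses a map
// to count complementary paths. Complexity: O(n * log(n) * log(n) * 30).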
#include <iostream>
#include <cctype>
#include <cstdio>
#include <queue>
#include <map>
#include <algorithm>
#include <cstring>
#include <cstdlib>
using namespace std;
typedef long long LL;

// Reads the next unsigned decimal integer of type T from stdin.
// Any non-digit characters before the number are skipped (so a '-' sign is
// ignored); the character that terminates the number is consumed.
// Returns 0 if EOF is reached before any digit is found.
template <typename T>
static T read_unsigned() {
    // Keep the result of getchar() in an int: a char cannot hold EOF
    // distinctly, and passing a negative char to isdigit is undefined.
    int c = getchar();
    while (c != EOF && !std::isdigit(c))
        c = getchar();
    T x = 0;
    while (c != EOF && std::isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}

// Reads the next unsigned int from stdin (0 at EOF).
inline int readint() {
    return read_unsigned<int>();
}

// Reads the next unsigned long long from stdin (0 at EOF).
inline LL readlong() {
    return read_unsigned<LL>();
}
// Capacity limits: at most N vertices, and M half-edges because every
// undirected tree edge is stored twice in the adjacency arrays.
const int N = 50005;
const int M = N << 1;

// Problem input. The fixed-size [30] arrays assume k <= 30 (TODO confirm
// against the problem limits; a larger k would overflow them).
int n, k;
int fac[30];        // the k primes, in input order
LL A[N];            // raw vertex values, divided down in place while factoring
int val[N][30];     // val[v][j]: exponent of fac[j] in vertex v, modulo 3
LL key[N];          // base-3 encoding of val[v][0..k-1]

// Scratch state shared by the solver routines.
int tar[30];        // target digit sum (mod 3) for the part of a path excluding the centroid
int tmp[30];        // general-purpose digit buffer
int tmp2[30];       // running digit sum along the current DFS path
map<LL, int> cnt;   // encoded digit vector -> multiplicity
map<LL, int>::iterator it;  // shared map iterator

#define fi first
#define se second
// Decodes `a` into its k base-3 digits, stored in t[0..k-1] with the most
// significant digit first; missing leading digits are 0.
// Assumes 0 <= a < 3^k (a larger value would write below t[0]).
inline void turnto(int* t, LL a) {
    fill(t, t + k, 0);
    for (int pos = k - 1; a != 0; --pos, a /= 3)
        t[pos] = static_cast<int>(a % 3);
}
// Encodes the k base-3 digits t[0..k-1] (most significant first) as one
// integer; this is the inverse of turnto() for digits in 0..2.
inline LL value(int* t) {
    LL code = 0;
    for (int pos = 0; pos < k; ++pos)
        code = code * 3 + t[pos];
    return code;
}
// Array-based adjacency list: the edges of u are head[u], next[head[u]], ...
// terminated by -1, and to[e] is the far endpoint of half-edge e.
// NOTE: with `using namespace std`, unqualified `next` and `size` are ambiguous
// with std::next (C++11) / std::size (C++17); refer to them as ::next / ::size.
int head[N];
int next[M];
int to[M];

bool done[N];   // vertex already used as a centroid (removed from the tree)
int size[N];    // subtree size computed by gao()
int d[N];       // largest piece left if this vertex were removed (centroid choice)
LL ver[N];      // encoded path digit sums collected by dfs()
int sub[N];     // vertices of the current component, in gao() visit order

int E;          // number of half-edges stored so far
int tot;        // size of the component explored by the last gao() pass
int root;       // current centroid
int sz;         // number of entries currently in ver
LL ans;         // running answer
// Resets per-test-case state: empty adjacency lists and no removed vertices
// for vertices 1..n, zero stored edges, and a zero answer.
void init() {
    ans = 0;
    E = 0;
    for (int vtx = n; vtx >= 1; --vtx) {
        done[vtx] = 0;
        head[vtx] = -1;
    }
}
void add(int u, int v) {
to[E] = v, next[E] = head[u], head[u] = E++;
to[E] = u, next[E] = head[v], head[v] = E++;
}
// Counts unordered pairs of positions {i, j} (i != j) in [s, e) whose encoded
// digit vectors sum to `tar` digit-wise modulo 3.
// Uses the global map `cnt` and the scratch buffer `tmp` (both overwritten).
LL count(LL* s, LL* e) {
    cnt.clear();
    for (LL* p = s; p != e; ++p)
        ++cnt[*p];

    LL res = 0;
    while (!cnt.empty()) {
        map<LL, int>::iterator first = cnt.begin();
        const LL v1 = first->first;
        // Copy the multiplicity before erasing: the original read it through
        // the erased iterator (undefined behavior).
        const LL c1 = first->second;
        cnt.erase(first);

        // v2 is the unique vector that complements v1 to tar.
        turnto(tmp, v1);
        for (int i = 0; i < k; i++)
            tmp[i] = (tar[i] + 3 - tmp[i]) % 3;
        const LL v2 = value(tmp);

        if (v1 == v2) {
            // Self-complementary: choose 2 among c1 equal entries.
            res += c1 * (c1 - 1) / 2;
        } else {
            // find() rather than operator[] avoids inserting a phantom entry;
            // the product is done in LL because c1 * c2 can exceed INT_MAX.
            map<LL, int>::iterator partner = cnt.find(v2);
            if (partner != cnt.end()) {
                res += c1 * partner->second;
                cnt.erase(partner);
            }
        }
    }
    return res;
}
void gao(int u, int fa) {
sub[tot] = u;
int t = tot++;
int omax = 0;
for (int i = head[u]; i != -1; i = next[i]) {
int v = to[i];
if (!done[v] && v != fa) {
gao(v, u);
omax = max(omax, size[v]);
}
}
size[u] = tot - t;
if (size[u] > 1)
d[u] = omax;
}
int center(int u) {
tot = 0;
gao(u, 0);
if (tot == 1) return u;
int key = N;
int res;
for (int k = 0; k < tot; k++) {
int v = sub[k];
d[v] = max(tot - size[v], d[v]);
if (size[v] > 1 && d[v] < key) {
key = d[v];
res = v;
}
}
return res;
}
void dfs(int u, int fa) {
int t = sz;
if (fa) {
LL sum = 0;
for (int i = 0; i < k; i++) {
sum += tmp2[i];
if (i == k - 1) break;
sum *= 3;
}
ver[sz++] = sum;
if (value(tmp2) == value(tar)) {
ans++;
}
}
for (int i = head[u]; i != -1; i = next[i]) {
int v = to[i];
if (v != fa && !done[v]) {
for (int j = 0; j < k; j++)
tmp2[j] += val[v][j], tmp2[j] %= 3;
dfs(v, u);
for (int j = 0; j < k; j++) {
tmp2[j] -= val[v][j];
if (tmp2[j] < 0)
tmp2[j] += 3;
}
}
}
if (fa == root && u != root) {
ans -= count(ver + t, ver + sz);
}
}
void solve(int u) {
root = center(u);
if (tot == 1) {
done[u] = 1;
return;
}
sz = 0;
turnto(tmp, key[root]);
for (int i = 0; i < k; i++) {
tar[i] = (3 - tmp[i]) % 3;
}
memset(tmp2, 0, sizeof(tmp2));
dfs(root, 0);
ans += count(ver, ver + sz);
done[root] = 1;
for (int i = head[root]; i != -1; i = next[i])
if (!done[to[i]])
solve(to[i]);
}
int main() {
int SIZE = 256 << 20; // 256MB
char *p = (char*)malloc(SIZE) + SIZE;
__asm__("movl %0, %%esp\n" :: "r"(p) );
int u, v;
while (~scanf("%d", &n)) {
k = readint();
for (int i = 0; i < k; i++)
fac[i] = readint();
init();
for (int i = 1; i <= n; i++) {
A[i] = readlong();
memset(tmp, 0, sizeof(tmp));
for (int j = 0; j < k; j++) {
while (A[i] % fac[j] == 0) {
tmp[j]++;
A[i] /= fac[j];
}
tmp[j] %= 3;
val[i][j] = tmp[j];
}
key[i] = value(tmp);
if (key[i] == 0) ans++;
}
for (int i = 0; i < n - 1; i++) {
u = readint(), v = readint();
add(u, v);
}
solve(1);
printf("%I64d\n", ans);
}
return 0;
}