题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4670
算法:基于点的树链分治
思路:
//树基于点的分治算法,可以参见国家集训队论文:2009年漆子超《分治算法在树的路径问题中的应用》
/*
000111222
+ 012012012
----------------------
= 012120201
*/
//Mul[u,v] 等于点u到点v的乘积中,各个素数的个数对3取余后的一个数列
//对于每一个Mul[u, S], 求Mul(S, v]+Mul[u,S]=0的个数(S为分治点)
//用map记录下所有Mul[S,u]的状态(用longlong保存3进制数),再用一个map保存取出分治点S以后(S,v]的状态,对于第一个map中的每一个状态,求在第二个map中是否存在它的互补状态
注意:
//第一次遇到手动扩栈的题,检查了好长时间
//素数有可能超过int
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<map>
using namespace std;
#define LL __int64
const int M = 50010;
LL pri[35];//保存素数
int Num;//素数的个数
int num[M][30];//每个数的化简
struct Edge {
int v, next;
} edge[M << 1];
int head[M], E;//邻接表参数
LL ans = 0;//保存最终结果
int mmax, root, size[M], mx[M], vis[M];//求重点的参数
int tmp_cnt;//计算类似8,27,125这样单独成区间的区间个数
int dis[M][30];//保存区间[S,i]中第j个素数的个数dis[i][j]
map<LL, LL> er, mp;//er是计算3进制下需要匹配的数的个数,mp是计算分治点S到目标点Q,Mul[S,Q]在3进制下得到的数
map<LL, LL>::iterator it;
void init() {
ans = E = 0;
memset(vis, 0, sizeof(vis));
memset(head, -1, sizeof(head));
memset(num, 0, sizeof(num));
}
void add_edge(int s, int v) {
edge[E].v = v;
edge[E].next = head[s];
head[s] = E++;
}
void dfs_uu(int u, int fa) {
size[u] = 1;
mx[u] = 0;
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (v == fa || vis[v])continue;
dfs_uu(v, u);
size[u] += size[v];
if (size[v] > mx[u]) mx[u] = size[v];
}
}
void dfs_u(int u, int fa, int r) {
if (size[r] - size[u] > mx[u]) mx[u] = size[r] - size[u];
if (mmax > mx[u]) { mmax = mx[u]; root = u;}
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (v == fa || vis[v]) continue;
dfs_u(v, u, r);
}
}//以上两个函数求重点
void dfs_dis(int u, int fa, int r) {
LL tmp = 0, sum = 0, sum_nee = 0;
for (int i = 0; i < Num; i++) {
dis[u][i] = dis[fa][i] + num[u][i];
if (dis[u][i] > 2) dis[u][i] -= 3;
tmp = 3 - dis[u][i] + dis[r][i]; if (tmp >= 3) tmp -= 3;
sum = sum * 3 + tmp;
sum_nee = sum_nee * 3 + dis[u][i];
}
if (sum == sum_nee) tmp_cnt++;
er[sum]++;
mp[sum_nee]++;
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (v == fa || vis[v])
continue;
dfs_dis(v, u, r);
}
}
LL cala(int u, int fa) {
mp.clear(); er.clear();
LL tt = 0;
tmp_cnt = 0;
if (fa != 0) for (int i = 0; i < Num; i++) dis[fa][i] = num[fa][i];
dfs_dis(u, fa, fa == 0 ? u : fa);
for (it = er.begin(); it != er.end(); it++) {
if (mp.find(it->first) != mp.end())
tt += it->second * mp[it->first];
}
return (tt - tmp_cnt) / 2 + tmp_cnt;
}//以上两个【核心】函数求分治区间中满足条件的点对~
void solve(int u) {
int rt = 0;
mmax = 123456;
root = u;
dfs_uu(u, -1);
dfs_u(u, -1, u);
ans += cala(root, 0);
vis[root] = 1;
rt = root;
for (int i = head[rt]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (vis[v]) continue;
ans -= cala(v, rt);
solve(v);
}
}
int main() {
int n;
while (scanf("%d", &n) != EOF) {
int i, j;
LL val;
init();
scanf("%d", &Num);
for (i = 0; i < Num; i++) {
scanf("%I64d", &pri[i]);
}
for (i = 1; i <= n; i++) {
scanf("%I64d", &val);
for (j = 0; j < Num; j++) {
while (val % pri[j] == 0) {
num[i][j]++;
val = val / pri[j];
if (num[i][j] > 2)
num[i][j] -= 3;
}
}
}
int a, b;
for (i = 0; i < n - 1; i++) {
scanf("%d%d", &a, &b);
add_edge(a, b);
add_edge(b, a);
}
solve(1);
printf("%I64d\n", ans);
}
return 0;
}