题意:每个顶点有一个权值,求权值乘积为立方数的路径个数。路径节点数可以为1。
解法:典型的树分治。
1.dfs求出通过中心的每个子树中路径长度集合,在已经遍历的子树中路径长度集合中找到每条路径的“补路径”个数,,即乘积为立方数的路径。累加起来即为结果
要点:
1.用a数组记录每个子树的路径长度集合
2.用map标记已遍历的子树路径长度及其个数
3.需要扩栈
#pragma comment(linker, "/STACK:102400000,102400000")
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <map>
#define ll long long
#define INF 0x3f3f3f3f
using namespace std;
const ll maxn =50000+10;
ll n,K,pr[35];
struct Node{
ll num[35];
}node[maxn],d[maxn];
vector<ll> g[maxn];
ll son[maxn],maxs[maxn];
ll core;
bool vis[maxn];
ll ans;
ll sum;
map<ll,ll> ma;
vector<ll> a;
ll Power(ll a,ll b){
ll ans=1;
while(b){
if(b&1) { ans=ans*a; b--; }
a=a*a,b>>=1;
}
return ans;
}
void getcore(ll u,ll fa){
son[u]=1,maxs[u]=0;
for(ll i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==fa||vis[v]) continue;
getcore(v,u);
son[u]+=son[v];
maxs[u]=max(maxs[u],son[v]);
}
maxs[u]=max(maxs[u],sum-son[u]);
if(maxs[u]<maxs[core]) core=u;
}
void dfs(ll u,ll fa){
ll tmpc=1.0;
ll x=1;
son[u]=1;
for(ll j=0;j<K;j++){
d[u].num[j]=((d[fa].num[j]+node[u].num[j])%3+3)%3;
x*=Power(pr[j],d[u].num[j]);
}
a.push_back(x);
//每次在计算路径乘积和互补路径长度时插入中心节点的值
for(ll i=0;i<K;i++){
ll tmp=((3-d[u].num[i]-node[core].num[i])%3+3)%3;
tmpc=tmpc*Power(pr[i],tmp);
}
ans+=ma[tmpc];
for(ll i=0;i<g[u].size();i++){
ll v=g[u][i];
if(v==fa||vis[v]) continue;
dfs(v,u);
son[u]+=son[v];
}
}
void work(ll u){
ma[1]=1; vis[u]=1;
for(ll i=0;i<g[u].size();i++){
ll v=g[u][i];
if(vis[v]) continue;
//不计入重心节点的值
for(ll j=0;j<K;j++) d[u].num[j]=0;
dfs(v,u);
for(ll j=0;j<a.size();j++) { ma[a[j]]++; }
a.clear();
}
ma.clear();
for(ll i=0;i<g[u].size();i++){
ll v=g[u][i];
if(vis[v]) continue;
sum=son[v]; core=0; getcore(v,u);
work(core);
}
}
int main(){
//freopen("a.txt","r",stdin);
while(scanf("%lld",&n)!=EOF){
scanf("%lld",&K);
for(ll i=0;i<K;i++) scanf("%lld",&pr[i]);
for(int i=1;i<=n;i++){
memset(node[i].num,0,sizeof(node[i].num));
memset(d[i].num,0,sizeof(d[i].num));
}
for(ll i=1;i<=n;i++) g[i].clear();
ans=0;
for(ll i=1;i<=n;i++){
ll tmp,cnt;
scanf("%lld",&tmp);
for(ll j=0;j<K;j++){
cnt=0;
if(tmp%pr[j]==0){
while(tmp%pr[j]==0){
cnt++;
tmp/=pr[j];
}
}
node[i].num[j]=cnt%3;
}
bool flag=true;
for(ll j=0;j<K;j++){
if(node[i].num[j]!=0) flag=false;
}
if(flag) ans++;
}
for(ll i=1;i<n;i++){
ll u,v; scanf("%lld%lld",&u,&v);
g[u].push_back(v),g[v].push_back(u);
}
core=0,maxs[core]=INF;
memset(vis,0,sizeof(vis));
ma.clear();
a.clear();
sum=n;
getcore(1,0);
work(core);
printf("%lld\n",ans);
}
return 0;
}