#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
using namespace std;
typedef long long ll;
const int maxn=50005;
const int inf=0x3fffffff;
struct edge
{
int to,next;
}ee[50005<<1];
int e[50005],ecnt;
void addedge(int u,int v)
{
ee[ecnt].to=v;ee[ecnt].next=e[u];e[u]=ecnt++;
ee[ecnt].to=u;ee[ecnt].next=e[v];e[v]=ecnt++;
}
int pri[33],in[maxn][33];
int n,k;
map<ll,int> m1,m2;
bool del[maxn];
int size[maxn],opt[maxn],tnode[maxn],tns,all[maxn][33],as,id[maxn];
ll ans;
void dfs(int f,int u)
{
tnode[tns++]=u;
size[u]=1;opt[u]=0;
int i,v;
for(i=e[u];i!=-1;i=ee[i].next)
{
v=ee[i].to;
if(!del[v]&&v!=f)
{
dfs(u,v);
size[u]+=size[v];
opt[u]=max(opt[u],size[v]);
}
}
}
int get_root(int u)
{
tns=0;
dfs(-1,u);
int mi=inf,ans=-1,i;
for(i=0;i<tns;++i)
{
opt[tnode[i]]=max(opt[tnode[i]],size[u]-size[tnode[i]]);
if(opt[tnode[i]]<mi)
{
mi=opt[tnode[i]];
ans=tnode[i];
}
}
return ans;
}
void get_dis(int f,int u)
{
int fid,i,v;
if(f!=-1)
{
fid=id[f];
for(i=0;i<k;++i)
{
all[as][i]=(all[fid][i]+in[u][i])%3;
}
}
else
{
for(i=0;i<k;++i)
{
all[as][i]=in[u][i];
}
}
id[u]=as++;
for(i=e[u];i!=-1;i=ee[i].next)
{
v=ee[i].to;
if(!del[v]&&v!=f)
{
get_dis(u,v);
}
}
}
ll calc(int u)
{
int i,j,kk,v;
ll ret=0,s1,s2;
m1.clear();
bool flag=true;
s1=0;
for(i=0;i<k;++i)
{
if(in[u][i]!=0)
flag=false;
s1=(s1<<2)+in[u][i];
}
if(flag)
ret++;
m1[s1]=1;
for(i=e[u];i!=-1;i=ee[i].next)
{
v=ee[i].to;
if(del[v])
continue;
as=0;
get_dis(-1,v);
for(j=0;j<as;++j)
{
s2=0;
for(kk=0;kk<k;++kk)
{
s2=(s2<<2)+((3-all[j][kk])%3);
}
ret+=m1[s2];
}
for(j=0;j<as;++j)
{
s1=0;
for(kk=0;kk<k;++kk)
{
s1=(s1<<2)+((all[j][kk]+in[u][kk])%3);
}
m1[s1]++;
}
}
return ret;
}
ll solve(int u)
{
u=get_root(u);
ll ret,i,v;
del[u]=true;
ret=calc(u);
for(i=e[u];i!=-1;i=ee[i].next)
{
v=ee[i].to;
if(!del[v])
ret+=solve(v);
}
return ret;
}
int main()
{
int i,j,u,v;
ll tmp,tmp2;
while(scanf("%d",&n)!=EOF)
{
scanf("%d",&k);
for(i=0;i<k;++i)
scanf("%d",&pri[i]);
for(i=1;i<=n;++i)
{
scanf("%I64d",&tmp);
for(j=0;j<k;++j)
{
in[i][j]=0;
tmp2=tmp;
while(tmp2%pri[j]==0)
{
tmp2/=pri[j];
in[i][j]++;
}
in[i][j]%=3;
}
}
memset(e,-1,sizeof(e));ecnt=0;
for(i=1;i<n;++i)
{
scanf("%d%d",&u,&v);
addedge(u,v);
}
memset(del,false,sizeof(del));
printf("%I64d\n",solve(1));
}
return 0;
}
hdu4670
最新推荐文章于 2017-08-17 10:45:07 发布