题意:给出一颗树,每个点有一个权值,求所有路径中,路径经过的点的乘积是完全立方数的路径个数。点权由给定的素数的乘积组成,素数不超过30。
思路:如果是求和之类的还是很好求的,但是这个是乘积,所幸的是,素数最多只有30个,而且我们所要知道的信息只有每个素数对3取余之后的结果,因此,可以考虑用一个二进制数来表示从某一个点到当前根节点的乘积,两位表示一个素数对3取余的结果。用map来计数就行了。剩下的就比较好做了,每次找到根以后,计算根到每颗子树的乘积的值,用这个值可以算出互补的值,也就是说,这两个数的每个素数加起来以后对3取余等于0,然后map直接找有多少个就行了。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<stack>
#include<set>
#include<cmath>
#include<vector>
#include<bitset>
#define inf 0x3f3f3f3f
#define Inf 0x3FFFFFFFFFFFFFFFLL
#define eps 1e-6
#define pi acos(-1.0)
using namespace std;
typedef long long ll;
const int maxn = 50000 + 10;
struct Edge
{
int v,next;
Edge(int v = 0,int next = 0):v(v),next(next){}
}edges[maxn<<1];
int head[maxn],factor[maxn][33],pcnt,nEdge;
int sons[maxn],pa[maxn],q[maxn],flag[maxn];
ll ans,val[maxn],primes[33];
map<ll,int>mp;
void AddEdges(int u,int v)
{
edges[++nEdge] = Edge(v,head[u]);
head[u] = nEdge;
edges[++nEdge] = Edge(u,head[v]);
head[v] = nEdge;
}
int findroot(int x)
{
int tail = 0,u,v;
q[tail++] = x;
pa[x] = 0;
for(int i = 0;i < tail;++i)
{
u = q[i];
for(int k = head[u];k != -1;k = edges[k].next)
{
v = edges[k].v;
if(v == pa[u] || flag[v]) continue;
pa[v] = u;
q[tail++] = v;
}
}
int minv = inf,root = -1,mx;
for(int i = tail - 1;i >= 0;--i)
{
u = q[i];
sons[u] = 1;
mx = 0;
for(int k = head[u];k != -1;k = edges[k].next)
{
v = edges[k].v;
if(v == pa[u] || flag[v]) continue;
sons[u] += sons[v];
mx = max(mx,sons[v]);
}
mx = max(mx,tail - sons[u]);
if(mx < minv)
{
minv = mx;
root = u;
}
}
return root;
}
inline ll getmsk(int u)
{
ll res = 0;
for(ll i = 0;i < pcnt;++i)
res |= (ll)factor[u][i]<<(i<<1LL);
return res;
}
inline ll Uion(ll x,ll y)
{
ll res = 0;
for(ll i = 0;i < pcnt;i++)
{
res |= ((
((x>>(i<<1LL)) & 3) +
((y>>(i<<1LL)) & 3)
)%3)<<(i<<1LL);
}
return res;
}
inline ll getRev(ll x)
{
ll res = 0;
int v;
for(ll i = 0;i < pcnt;++i)
{
v = (x>>(i<<1LL)) & 3;
if(v == 1) res |= (2LL<<(i<<1));
else if(v == 2) res |= (1LL<<(i<<1));
}
return res;
}
void cal(int x,int fa,ll fval)
{
pa[x] = fa;
int tail = 0,u,v;
q[tail++] = x;
val[fa] = 0;
ll tmp;
for(int i = 0;i < tail;++i)
{
u = q[i];
val[u] = Uion(val[pa[u]],getmsk(u));
tmp = getRev(val[u]);
if(mp.find(tmp) != mp.end())
ans += mp[tmp];
for(int k = head[u];k != -1;k = edges[k].next)
{
v = edges[k].v;
if(v == pa[u] || flag[v]) continue;
pa[v] = u;
q[tail++] = v;
}
}
for(int i = tail - 1;i >= 0;--i)
{
u = q[i];
tmp = Uion(val[u],fval);
mp[tmp]++;
}
}
void solve(int x)
{
int root = findroot(x);
flag[root] = 1;
mp.clear();
ll rmsk = getmsk(root);
if(rmsk == 0) ans++;
mp[rmsk] = 1;
for(int k = head[root];k != -1;k = edges[k].next)
{
int v = edges[k].v;
if(flag[v]) continue;
cal(v,root,rmsk);
}
for(int k = head[root];k != -1;k = edges[k].next)
if(!flag[edges[k].v]) solve(edges[k].v);
}
int main()
{
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
int n;
while(~scanf("%d",&n))
{
memset(head,0xff,sizeof(head));
nEdge = -1;
memset(flag,0,sizeof(flag));
scanf("%d",&pcnt);
for(int i = 0;i < pcnt;++i)
scanf("%I64d",&primes[i]);
ll num;
for(int i = 1;i <= n;++i)
{
memset(factor[i],0,sizeof(factor[i]));
scanf("%I64d",&num);
for(int j = 0;j < pcnt;++j)
{
while(num % primes[j] == 0)
{
factor[i][j]++;
num /= primes[j];
}
factor[i][j] %= 3;
}
}
int u,v;
for(int i = 1;i < n;++i)
{
scanf("%d%d",&u,&v);
AddEdges(u,v);
}
ans = 0;
solve(1);
printf("%I64d\n",ans);
}
return 0;
}