题目:http://acm.hdu.edu.cn/showproblem.php?pid=4670
题目大意:给你一棵有n个顶点的树,每个节点有一个权值,给你k个prime,每个权值都可以由这k个prime的幂次方的和组成,问你在树上有多少条路径,使这条路径上的点的权值积是一个立方数。
思路:一个数是立方数,当且只当它的拆分成的所有质因子都是幂都是3的倍数,而质因子的4、7次幂和1次幂一样,5次、8次和2次一样,则对于幂次的个数,我们都对3取余,结果不变。每次选取这棵树的重心,然后算出所有经过这个点的路径数(暴搜),最后全部加起来即可。一个儿子,一个儿子搜过来,如果当前这个儿子这条路径的各因数幂次都知道了,那么我们只要加上之前有的它的互补的幂次的个数就好。互补就是说,加起来都是3的倍数,也就是全是0,比如:当前路径为0 1 2,那么只要看看 0 2 1这条路径之前出现的个数,加起来就行了。还有,这里的互补不能用 2 2 2 的满状态去剪应该是 3 3 3 的,比如:1 1 1,那么一剪还是1 1 1,其实是2 2 2。这里还有一个地方特别需要注意:k最大为30,可以用lld存下来表示状态,但是用数组哈希显然是不行的,用map,清零和哈希都很方便。
其实挺简单的说,功力太差,有个地方一直搞来搞去,搜的那条路径不包括根节点,然后更新 hash 的时候要加上根节点,初始化就是只包含根节点的状态为1。调了一个下午,挫了。。= =
搓代码一份,如下:
#pragma comment(linker, "/STACK:10240000000000,10240000000000")
#include<cstdio>
#include<cstring>
#include<map>
#include<vector>
#include<algorithm>
using namespace std;
typedef __int64 lld;
const int MAXN = 55555 ;
int n,k;
struct Edge
{
int next,t;
} edge[MAXN<<1];
int tot ,head[MAXN];
void add_edge(int s,int t)
{
edge[tot].t = t;
edge[tot].next = head[s];
head[s] = tot++;
}
struct Node
{
int cnt[33];
} node[MAXN];
int num[MAXN],maxv[MAXN];
int vis[MAXN];
void get_size(int u,int fa)
{
num[u] = 1;
maxv[u] = 0;
for(int e = head[u];e!=-1;e = edge[e].next)
{
int v = edge[e].t;
if(vis[v]||v==fa) continue;
get_size(v,u);
num[u] += num[v];
maxv[u] = max(maxv[u],num[v]);
}
}
int minn ;
void find_root(int u,int fa,int &root,int sum)
{
int tmp = max(sum - num[u],maxv[u]);
if(tmp < minn )
{
minn = tmp;
root = u;
}
for(int e = head[u];e!=-1;e=edge[e].next)
{
int v = edge[e].t;
if(vis[v]||fa==v) continue;
find_root(v,u,root,sum);
}
}
int get_root(int u)
{
get_size(u,-1);
int sum = num[u];
minn = n;
int root = u;
find_root(u,-1,root,sum);
return root;
}
lld exp[33];
void init()
{
exp[0] = 1;
for(int i = 1;i<=30;i++)
exp[i] = exp[i-1] * 3;
}
map <lld,int> sta;
int ss[33];
int ret;
vector <lld> vec;
void dfs(int u,int fa,int root)
{
for(int i = 0;i<k;i++)
ss[i] = (ss[i] + node[u].cnt[i])%3;
lld cc = 0,cc2 = 0;
for(int i = 0;i<k;i++)
{
cc += (3 - ss[i])%3*exp[i];
cc2 += (ss[i]+node[root].cnt[i])%3*exp[i];
}
vec.push_back(cc2);
ret += sta[cc];
for(int e = head[u];e!=-1;e = edge[e].next)
{
int v = edge[e].t;
if(vis[v]||v==fa) continue;
dfs(v,u,root);
for(int i = 0;i<k;i++)
ss[i] = (ss[i] - node[v].cnt[i] + 3)%3;
}
}
int count(int u)
{
ret = 0 ;
sta.clear();
lld cc = 0;
for(int i = 0;i<k;i++)
{
cc += node[u].cnt[i]*exp[i];
}
sta[cc] = 1;
if(cc == 0) ret = 1;
for(int e = head[u] ; e!=-1;e = edge[e].next)
{
int v = edge[e].t;
if(vis[v]) continue;
memset(ss,0,sizeof(ss));
vec.clear();
dfs(v,u,u);
for(int i = 0;i<vec.size();i++)
sta[vec[i]] ++ ;
}
//printf("ret = %d\n",ret);
return ret;
}
int ans ;
void solve(int u)
{
int root = get_root(u);
//printf("root = %d\n",root);
vis[root] = 1;
ans += count(root);
for(int e = head[root] ; e!=-1 ;e = edge[e].next)
{
int v = edge[e].t;
if(vis[v]) continue;
solve(v);
}
}
lld pri[33];
int main()
{
init();
while(~scanf("%d",&n))
{
scanf("%d",&k);
for(int i = 0;i<k;i++)
scanf("%I64d",&pri[i]);
for(int i = 0;i<n;i++)
{
lld tmp;
scanf("%I64d",&tmp);
memset(node[i].cnt,0,sizeof(node[i].cnt));
for(int j = 0;j<k;j++)
{
while(tmp&&(tmp%pri[j]==0))
{
node[i].cnt[j] ++;
node[i].cnt[j] = node[i].cnt[j]%3;
tmp = tmp/pri[j];
}
if( tmp == 0 )
break;
}
}
tot=0;
memset(head,-1,sizeof(head));
int a,b;
for(int i = 1;i<n;i++)
{
scanf("%d%d",&a,&b);
a--;
b--;
add_edge(a,b);
add_edge(b,a);
}
memset(vis,0,sizeof(vis));
ans = 0;
solve(0);
printf("%d\n",ans);
}
return 0;
}
/*
6
2 2 3
36 36 36 36 36 36
1 2
2 3
3 4
4 5
5 6
*/