链接:
https://www.nowcoder.com/acm/contest/91/B
来源:牛客网
来源:牛客网
题目描述
在埃森哲,员工培训是最看重的内容,最近一年,我们投入了 9.41 亿美元用于员工培训和职业发展。截至 2018 财年末,我们会在全球范围内设立 100 所互联课堂,将互动科技与创新内容有机结合起来。按岗培训,按需定制,随时随地,本土化,区域化,虚拟化的培训会让你快速取得成长。小埃希望能通过培训学习更多ACM 相关的知识,他在培训中碰到了这样一个问题,
给定一棵
n个节点的树,并且根节点的编号为
p,第
i个节点有属性值
vali, 定义
F(i): 在以
i为根的子树中,属性值是
vali的合约数的节点个数。y 是 x 的合约数是指 y 是合数且 y 是 x 的约数。小埃想知道
对
1000000007取模后的结果
.
![](https://i-blog.csdnimg.cn/blog_migrate/26e0a980708efae5182b0a8d11d6da96.png)
输入描述:
输入测试组数T,每组数据,输入n+1行整数,第一行为n和p,1<=n<=20000, 1<=p<=n, 接下来n-1行,每行两个整数u和v,表示u和v之间有一条边。第n+1行输入n个整数val1, val2,…, valn,其中1<=vali<=10000,1<=i<=n.
输出描述:
对于每组数据,输出一行,包含1个整数, 表示对1000000007取模后的结果
示例1
输入
2 5 4 5 3 2 5 4 2 1 3 10 4 3 10 5 3 3 1 3 2 1 1 10 1
输出
11 2
备注:
n>=10000的有20组测试数据
题解:
首先预处理1-N里的每个数的合约数,复杂度O(N*sqrt(N)), 然后根据DFS序记录目前为止每个数出现的次数,第一次到这个节点i时候先F[i]先减去他之前合约数出现的总次数,再次回溯的时候再次加上他合数出现的次数,两次得到的即为F[i].
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 20100;
vector<int>g[N];
int val[N];
bool isprime[N];
vector<short>f[N];
int cnt[N], F[N];
ll res;
const ll MOD = 1E9 + 7;
void init()
{
for(int i = 2;i < N;i ++) isprime[i] = 1;
for(int i = 2;i < N;i ++) {
if(isprime[i]) {
for(int j = i + i;j < N;j += i) {
isprime[j] = 0;
}
}
}
for(int i = 4;i <= 10000;i ++) {
if(!isprime[i]) {
for(int j = i;j <= 10000;j += i) {
f[j].push_back(i);
}
}
}
}
void dfs(int u, int par)
{
for(int i = 0;i < f[val[u]].size();i ++) {
int num = f[val[u]][i];
F[u] -= cnt[num]; //根据dfs序,先减去他之前出现的
}
cnt[val[u]] ++;
for(int i = 0;i < g[u].size();i ++) {
int v = g[u][i];
if(v == par) continue;
dfs(v, u);
}
for(int i = 0;i < f[val[u]].size();i ++) {
int num = f[val[u]][i];
F[u] += cnt[num]; //回溯到他时,+现在出现过的总数 = 遍历完其子树后的总合约数个数-之前出现的合约数个数.
}
res += F[u] * 1LL * u;
res %= MOD;
}
int main()
{
int T;
init();
scanf("%d", &T);
while(T --) {
memset(cnt, 0, sizeof(cnt));
res = 0;
int n, p;
scanf("%d %d", &n, &p);
for(int i = 1;i <= n;i ++) F[i] = 0,g[i].clear();
for(int i = 1;i < n;i ++) {
int u, v;
scanf("%d %d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
for(int i = 1;i <= n;i ++) {
scanf("%d", &val[i]);
}
dfs(p,-1);
printf("%lld\n", res);
}
return 0;
}