合约数
题意
有一棵 n n n 个点 n − 1 n-1 n−1 条边的带权树,每个点的权值是 a [ i ] a[i] a[i] ,定义 F ( i ) F(i) F(i) 为以 i i i 为根的所有子树中,点的权值是合数且是 a [ i ] a[i] a[i] 的约数的个数。求 ∑ i = 1 n i ⋅ F ( i ) \sum_{i=1}^{n} i\cdot F(i) ∑i=1ni⋅F(i) 。结果取模 1 0 9 + 7 10^9+7 109+7 。
解法
可离线的子树问题,考虑树上启发式合并。记录子树每个权值出现的次数,然后算 a [ u ] a[u] a[u] 的每个合约数的贡献就可以了。由于 a [ i ] ≤ 1 0 4 a[i]\le 10^4 a[i]≤104 ,所以可以先预处理出任意一个数的所有合约数。暴力枚举所有合约数,对所有合约数在子树中出现次数加和即可,就得到了 F ( i ) F(i) F(i) 。然后累加答案即可。
代码
#pragma region
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;
typedef long long ll;
#define rep(i, a, n) for (int i = a; i <= n; ++i)
#pragma endregion
const int maxn = 2e4 + 5;
const ll mod = 1000000007;
int n, r, a[maxn];
int sz[maxn], son[maxn];
vector<int> g[maxn], v[maxn];
int ans[maxn], cnt[maxn], flag;
bool pri[maxn];
void findP() {
for (int i = 2; i <= 10000; ++i) {
if (pri[i]) {
int j;
for (j = 1; j * j < i; ++j)
if (i % j == 0) {
if (pri[j]) v[i].push_back(j);
if (pri[i / j]) v[i].push_back(i / j);
}
if (j * j == i && pri[j]) v[i].push_back(j);
}
for (int j = i * 2; j <= 10000; j += i)
pri[j] = i;
}
}
int cul(int u) {
int ans = 0;
for (auto i : v[a[u]]) ans += cnt[i];
return ans;
}
void count(int u, int f, int val) {
cnt[a[u]] += val;
for (auto v : g[u]) {
if (v == f || v == flag) continue;
count(v, u, val);
}
}
void dfs1(int u, int f) {
sz[u] = 1;
for (auto v : g[u]) {
if (v == f) continue;
dfs1(v, u);
sz[u] += sz[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
void dfs(int u, int f, bool keep) {
for (auto v : g[u]) {
if (v == f || v == son[u]) continue;
dfs(v, u, 0);
}
if (son[u]) {
dfs(son[u], u, 1);
flag = son[u];
}
count(u, f, 1);
flag = 0;
ans[u] = cul(u);
if (!keep) {
count(u, f, -1);
}
}
int main() {
int T;
scanf("%d", &T);
findP();
while (T--) {
scanf("%d%d", &n, &r);
rep(i, 1, n) g[i].clear(), son[i] = 0;
rep(i, 1, n - 1) {
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
rep(i, 1, n) scanf("%d", &a[i]);
dfs1(r, 0);
dfs(r, 0, 0);
ll res = 0;
rep(i, 1, n) res = (res + i * ans[i]) % mod;
printf("%lld\n", res);
}
}