题意
有 $n$ 个点的一棵树，每个点有一个权值 $a_i$。对于每一个 $i$，问满足 $g(x,y)=i$ 的点对 $(x,y)$（$x\leq y$）的个数，其中 $g(x,y)$ 表示 $x$ 到 $y$ 的路径上所有 $a_i$ 的 $\gcd$。
题解
Tip1
点分治：对每一个分治中心，统计它到当前连通块中各个点的路径 $\gcd$，用 map 存下来。
点分治本身的复杂度是 $O(n\log n)$，再加上求 $\gcd$ 和 map 统计，大概多出三个 $\log$，所以总复杂度大约是 $O(n\log^4 n)$。
#include<cstdio>
#include<map>
#include<algorithm>
#include<iostream>
using namespace std;
// Capacity for vertex ids; ans[] is also indexed by gcd values, so weights must stay below it.
const int N = 3e5;

// Adjacency list kept as linked lists of directed edges.
// Edge ids start at 1, so id 0 terminates a vertex's list.
struct Edge {
    int u, v, nxt;  // tail, head, id of the next edge leaving u (0 = none)
} e[N * 2];         // each undirected tree edge is stored twice (see main)
int head[N], en, a[N];  // head[x]: newest edge id leaving x; en: edges used; a[x]: weight of x

// Push the directed edge x -> y onto the front of x's edge list.
void addl(int x, int y) {
    ++en;
    e[en] = Edge{x, y, head[x]};
    head[x] = en;
}

#define fi first
#define se second
using ll = long long;

ll ans[N];                // ans[g]: number of pairs (x, y), x <= y, whose path gcd is g
bool vis[N];              // vertices already chosen as a decomposition root
int n;                    // number of vertices
int rt;                   // best root candidate found by the latest getroot() call
std::map<int, int> c, d;  // gcd -> count tables filled while processing one decomposition root
// Walk the part of the current component hanging below x (entered from F).
// g is the gcd of the weights on the path from the current decomposition root down to x,
// inclusive. Every visited vertex is tallied in d under its path gcd, and also counted
// directly in ans, since (root, vertex) is itself one of the pairs being counted.
void get(int x, int F, int g) {
    d[g] += 1;
    ans[g] += 1;
    for (int id = head[x]; id != 0; id = e[id].nxt) {
        const int to = e[id].v;
        if (to != F && !vis[to])
            get(to, x, __gcd(g, a[to]));
    }
}
void solve(int x) {
++ans[a[x]];
c.clear();
for(int i = head[x];i; i = e[i].nxt) {
int y = e[i].v;
if(vis[y]) continue;
d.clear();
get(y, x, __gcd(a[x], a[y]));
for(auto &p : c)
for(auto &q : d) {
int t = __gcd(p.fi, q.fi);
ans[t] += (ll)p.se * q.se;
// printf("%d:++", t);
}
for(auto &q : d) c[q.fi] += q.se;
}
// for(auto &p: c)
// cout<<"DD"<<p.fi<<" "<<p.se<<endl;
}
int sum;            // size of the component currently searched for a centroid
int siz[N], mx[N];  // siz[x]: subtree size in this DFS; mx[x]: largest piece left if x is removed

// DFS over x's current component (coming from F). Computes subtree sizes and keeps in rt
// the vertex whose removal leaves the smallest largest piece, i.e. the centroid.
// Callers set sum to the component size and rt to 0 first; mx[0] acts as a huge sentinel.
void getroot(int x, int F) {
    siz[x] = 1;
    mx[x] = 0;
    for (int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].v;
        if (y == F || vis[y]) continue;
        getroot(y, x);
        siz[x] += siz[y];
        mx[x] = max(mx[x], siz[y]);
    }
    // The piece above x holds every vertex outside x's subtree: sum - siz[x].
    // (Using sum - mx[x] here is not a piece size and can select a non-centroid.)
    mx[x] = max(mx[x], sum - siz[x]);
    if (mx[x] < mx[rt]) rt = x;
}
// Centroid-decomposition driver. x is the chosen root of its current component, whose
// size is in sum on entry. Counts all pairs through x, then recurses into each remaining piece.
void dfs(int x, int F) {
    const int total = sum;  // save: the recursive calls below overwrite sum
    vis[x] = 1;
    solve(x);
    for (int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].v;
        if (vis[y]) continue;
        // siz[] still holds the sizes from the getroot() pass that found x, which was rooted
        // elsewhere in this component. A neighbour below x in that pass has an exact siz[y];
        // the neighbour above x (siz[y] > siz[x]) owns everything outside x's subtree.
        sum = siz[y] < siz[x] ? siz[y] : total - siz[x];
        rt = 0;
        getroot(y, x);
        dfs(rt, x);
    }
}
map<int,int> t;  // NOTE(review): not referenced anywhere in the visible code

// Reads n, the weights a[1..n] and the n-1 tree edges, then prints "g count" for every
// g with a nonzero count, in increasing order of g.
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        addl(x, y);
        addl(y, x);
    }
    mx[0] = 1e9;  // sentinel so any real vertex beats rt = 0 in getroot()
    sum = n;
    getroot(1, 0);
    dfs(rt, 0);
    // A path gcd never exceeds the weights on that path, so no nonzero ans[] lies above the
    // largest weight. Scanning up to it (rather than a fixed 2e5) avoids silently dropping
    // counts for larger weights; weights are still assumed to be below N.
    const int hi = *max_element(a + 1, a + n + 1);
    for (int i = 1; i <= hi; ++i)
        if (ans[i]) printf("%d %lld\n", i, ans[i]);
    return 0;
}