题意
给出一棵n个点,n-1条边的树。现在计算所有标号为x到y的距离之和(满足y>x且y是x的倍数)
思路
关于树上任意两点距离之和,一开始想到树形dp,可树形dp,是对每条边,求所有可能的路径经过此边的次数,是求出边两端的点数,这条边被经过的次数就是两端点数的乘积。
但是该题对计算的距离加了限制(y>x且y是x的倍数),显然不能用树形dp来做了。
接下来想到图论部分的算法,想处理出来两点之间的距离,也就是最短路,但是n=2e5,跑n遍dijkstra或者n遍spfa在复杂度上肯定会T,故不可行。
队友提到LCA(最近公共祖先),随意定义一个根root,用LCA倍增法预处理出来每个点到根的深度数组dep[],然后找两个点的最近公共祖先lca,那么这两点间的距离即 dep[i] + dep[i*j] - dep[lca] * 2 + 1
(根据题意,距离是这两点间边的个数+1)
代码是kuangbin的LCA倍增法板子改了一波。
AC代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
const int maxn = 2e5+5;
const int DEG = 20;
typedef long long ll;
struct Edge{
int to, next;
}edge[maxn*2];
int tot, head[maxn];
void init() {
tot = 0;
memset(head, -1, sizeof head);
}
void addedge(int u, int v) {
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}
int fa[maxn][20];
int dep[maxn];
void BFS(int root) {
queue<int> que;
dep[root] = 0;
fa[root][0] = root;
que.push(root);
while(!que.empty()) {
int tmp = que.front();
que.pop();
for(int i = 1; i < DEG; i++)
fa[tmp][i] = fa[fa[tmp][i-1]][i-1];
for(int i = head[tmp]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if(v==fa[tmp][0]) continue;
dep[v] = dep[tmp] + 1;
fa[v][0] = tmp;
que.push(v);
}
}
}
int LCA(int u, int v) {
if(dep[u] > dep[v]) swap(u, v);
int hu = dep[u], hv = dep[v];
int tu = u, tv = v;
for(int det = hv-hu, i = 0; det; det >>= 1, i++) {
if(det & 1) {
tv = fa[tv][i];
}
}
if(tu == tv) {
return tu;
}
for(int i = DEG-1; i >= 0; i--) {
if(fa[tu][i] == fa[tv][i]) {
continue;
}
tu = fa[tu][i];
tv = fa[tv][i];
}
return fa[tu][0];
}
bool flag[maxn];
int main() {
int n;
int u, v;
int root;
scanf("%d", &n);
init();
memset(flag, 0, sizeof flag);
for(int i = 0; i < n-1; i++) {
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
flag[v] = true;
}
for(int i = 1; i <= n; i++) {
if(!flag[i]) {
root = i;
break;
}
}
BFS(root);
ll sum = 0;
int lca;
for(int i = 1; i <= n; i++) {
for(int j = 2; i*j <= n; j++) {
lca = LCA(i, i*j);
sum += dep[i] + dep[i*j] - dep[lca]*2 + 1;
}
}
printf("%lld\n", sum);
return 0;
}