三点在树上距离相等的情况只有一种,就是以某一个点为中心,三个点到这个点的距离相等。
所以直接枚举每个点作为中心,dfs这个中心的子树,根据乘法原理统计答案即可。
时间复杂度 O(n2) (n <= 5000)
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#define N 5001
#define LL long long
LL ans;
int n, cnt;
int head[N], to[N << 1], next[N << 1], one[N], two[N], dis[N], tmp[N];
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1;
for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
return x * f;
}
inline void add(int x, int y)
{
to[cnt] = y;
next[cnt] = head[x];
head[x] = cnt++;
}
inline void dfs(int u)
{
int i, v;
tmp[dis[u]]++;
for(i = head[u]; i ^ -1; i = next[i])
{
v = to[i];
if(!dis[v])
{
dis[v] = dis[u] + 1;
dfs(v);
}
}
}
int main()
{
int i, j, k, x, y;
n = read();
memset(head, -1, sizeof(head));
for(i = 1; i < n; i++)
{
x = read();
y = read();
add(x, y);
add(y, x);
}
for(i = 1; i <= n; i++)
{
memset(dis, 0, sizeof(dis));
memset(one, 0, sizeof(one));
memset(two, 0, sizeof(two));
dis[i] = 1;
for(j = head[i]; j ^ -1; j = next[j])
{
memset(tmp, 0, sizeof(tmp));
dis[to[j]] = 2;
dfs(to[j]);
for(k = 1; k <= n; k++)
{
ans += (LL)two[k] * tmp[k];
two[k] += one[k] * tmp[k];
one[k] += tmp[k];
}
}
}
printf("%lld\n", ans);
return 0;
}