题目描述
大致意思就是给你n个点n-1条边,然后依次求出每个点到其他所有点的距离和
样例
Sample Input 1
3
1 2
2 3
Sample Output 1
3
2
3
Sample Input 2
2
1 2
Sample Output 2
1
1
Sample Input 3
6
1 6
1 5
1 3
1 4
1 2
Sample Output 3
5
9
9
9
9
9
算法:推公式加dfs求解
O(n)
我们先求一个点的
∑
j
N
d
i
s
t
(
i
,
j
)
\sum_j^Ndist(i,j)
j∑Ndist(i,j)
首先dfs出1的结果,记录为dist[1];
我们考察1的儿子节点s们的dist[]值会有什么区别:
对于1的儿子s和i;
如果i是不是s的儿子结点,那么 len(1,i)=len(s,i)+1;
如果i是s的结点,那么len(1,i)=len(s,i)-1;
因此我们假设所有i都不是s的儿子,那么dist[s]需要加上n(n个点),但是我们可以求出来s的儿子们的个数son[s],因此这些被多加了,原本应该减去,因此就是减去两倍,即:
d
i
s
t
[
s
]
=
d
i
s
t
[
j
]
+
n
−
s
∗
s
o
n
[
s
]
dist[s]=dist[j]+n-s*son[s]
dist[s]=dist[j]+n−s∗son[s]
因此做法就是:
1.做一遍深搜搜出dist[1],并且记录以1为根的树的所有结点的儿子结点个数
2.再做一遍深搜,按照由近到远的顺序,先求儿子结点的答案,再递归求儿子的儿子的答案即可
C++ 代码
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
//================================
#define debug(a) cout << #a": " << a << endl;
#define N 200010
//================================
typedef pair<int,int> pii;
#define x first
#define y second
typedef long long LL; typedef unsigned long long ULL; typedef long double LD;
inline LL read() { LL s = 0, w = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') w = -1; for (; isdigit(ch); ch = getchar()) s = (s << 1) + (s << 3) + (ch ^ 48); return s * w; }
inline void print(LL x, int op = 10) { if (!x) { putchar('0'); if (op) putchar(op); return; } char F[40]; LL tmp = x > 0 ? x : -x; if (x < 0)putchar('-'); int cnt = 0; while (tmp > 0) { F[cnt++] = tmp % 10 + '0'; tmp /= 10; } while (cnt > 0)putchar(F[--cnt]); if (op) putchar(op); }
//=================================
int n;
int e[2*N],ne[2*N],h[N],idx=0;
long long son[N],dist[N];
void add(int a,int b){
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs(int u,int pre,int len){
dist[1]+=len;
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==pre) continue;
dfs(j,u,len+1);
son[u]+=son[j];
}
}
void dfs(int u,int pre){
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==pre) continue;
dist[j]=dist[u]+n-2ll*son[j];
dfs(j,u);
}
}
//=================================
int main(){
memset(h,-1,sizeof h);
scanf("%d",&n);
for(int i=1;i<n;i++){
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
for(int i=1;i<=n;i++) son[i]=1;
dfs(1,-1,0);
dfs(1,-1);
for(int i=1;i<=n;i++)
printf("%lld\n",dist[i]);
return 0;
}