E - Expectations sky-high ( 求树的任意两点距离和 )
题目链接:https://vjudge.net/problem/Gym-102020E
题意:有一棵n个点的树,问任选两点 ( 不一定不同 ) 的概率长度是多长。
思路:dfs容易求出某个点到其他所有点的距离和是多少。难点在于如何用一次dfs求出所有点到其他所有点的距离和。定义一个点到其所有子节点的距离和为ans,我们先跑一次dfs求出1号节点的ans[ 1 ] ,过程肯定要用到数组sum[ u ]表示以1为根节点时,u到其子节点的距离和为多少,顺便更新一下siz[u]( 表示u的子节点数 )的大小。
重点来了,我们如何根据一个节点的ans[1] 及其相应的sum[],siz[],来处理出所有点的ans来。
fa
u x
... .... ...
上图的树中,fa是u的父亲节点,我们当前需要求u点的ans值,我们可以O(1) 来求u点的ans值 = u点到其子树的和值sum[u] + fa点到除u外其它子树的ans值, ans[fa]-sum[u]-siz[u] + u和其父亲的那条连边出现的次数(n-siz[u]).
int siz[maxn],sum[maxn],ans[maxn];
vector<int> G[maxn];
void dfs( int u, int fa ) // 求一个点的ans值,及其sum和siz数组
{
siz[u] = 1;
sum[u] = 0;
for ( int i=0; i<G[u].size(); i++ ) {
int v = G[u][i];
if ( v==fa ) continue ;
dfs(v,u);
siz[u] += siz[v]; siz[u]%=mod;
sum[u] += (sum[v] + siz[v])%mod; sum[u]%=mod;
}
}
void dfs2( int u, int fa ) // 根据一个点的ans,来O(1)求其他点的ans
{
if ( u!=1 ) {
ans[u] = (sum[u] + ( ans[fa]-sum[u]-siz[u] ) + ( n-siz[u] ))%mod; /// 重点!!画图理解
}
for ( int i=0; i<G[u].size(); i++ ) {
int v = G[u][i];
if ( v==fa ) continue ;
dfs2(v,u);
}
}
这个题是要算概率长度,我们算出总长度来除方案数就好了。
总长度 = ans数组和/2 。 除2是因为对于每一对点都算了两遍
方案数 = n*(n+1)/2 。
这样分子分母可以同乘2方便计算。
代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 1e9+7;
const int maxn = 2e5+10;
vector<int> G[maxn];
int siz[maxn],sum[maxn],ans[maxn];
int n;
int qpow( int a, int n )
{
int re = 1;
while ( n ) {
if ( n&1 ) re=(re*a)%mod;
a = (a*a)%mod;
n >>= 1;
}
return re;
}
void dfs( int u, int fa )
{
siz[u] = 1;
sum[u] = 0;
for ( int i=0; i<G[u].size(); i++ ) {
int v = G[u][i];
if ( v==fa ) continue ;
dfs(v,u);
siz[u] += siz[v]; siz[u]%=mod;
sum[u] += (sum[v] + siz[v])%mod; sum[u]%=mod;
}
}
void dfs2( int u, int fa )
{
if ( u!=1 ) {
ans[u] = (sum[u] + ( ans[fa]-sum[u]-siz[u] ) + ( n-siz[u] ))%mod; /// 重点!!画图理解
}
for ( int i=0; i<G[u].size(); i++ ) {
int v = G[u][i];
if ( v==fa ) continue ;
dfs2(v,u);
}
}
signed main()
{
cin >> n;
for ( int i=0; i<n-1; i++ ) {
int u,v;scanf("%lld %lld",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,1);
ans[1] = sum[1];
dfs2(1,1);
int tot = 0;
for ( int i=1; i<=n; i++ ) {
tot += ans[i];
tot %= mod;
}
cout << tot*qpow( n*(n+1)%mod ,mod-2)%mod << endl;
return 0;
}