参考博客:https://blog.csdn.net/qq_45458915/article/details/109294160
分割内容参考自学长博客
题目大意:给出一棵带权树,规定 ,解释一下就是当确定三个点 u1 , u2 , u3 后,需要找到一个点 v 到三个点的距离之和最小,现在给出 u1 , u2 , u3 的可行取值,问 f 函数的期望是多少
题目分析:考虑转换模型,对于给定的 u1 , u2 和 u3 来说,不难猜出点 v 是唯一存在的(不会证明),相应的这个最短的距离之和也是唯一确定的,且可以表示为
这样一来根据两个期望的基本公式进行转换:
- E( X + Y ) = E( X ) + E( Y )
- E( CX ) = CE( X )
如此一来就将 u1 , u2 , u3 的贡献拆成了分别独立的三组,再考虑对于 E( dis( u , v ) ) 该如何去求
现在问题就是如何快速求出 E( dis( u1 , u2 ) ) 了
接下来一个思维点就是,需要想到计算每条边的贡献,具体就是,对于一条边 ( u , v ) 来说,当移除掉这条边后,整棵树将会被分成不连通的两个部分,记为 T1 和 T2,比较显然的是:
- T1 中的 u1 到 T2 中的 u2 必然会经过当前边
- T1 中的 u2 到 T2 中的 u1 必然会经过当前边
直接树形 dp 就好了
学长最后的那部分具体内容是这样的。这部分的原型是一个求树上任意两点之间距离和的问题。
暴力的话用lca也要n^2。所以这部分是一个经典的树上dp问题。
给出一个原型题:http://acm.hdu.edu.cn/showproblem.php?pid=2376
思路:
建完树后,预处理出来每个节点的size[],然后转化成求每条边对答案的贡献。
考虑单条边对答案的贡献:比如(u,v)这条边,v的siz[v]里的每一个,和n-siz[v]的剩下中的每一个,之间的道路都会用到这条边,所以可以用dp来处理。
即(u,v)边对答案的贡献是siz[v]*(n-siz[v])*w;
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<map>
#include<set>
#include<cstdio>
#include<algorithm>
#define debug(a) cout<<#a<<"="<<a<<endl;
using namespace std;
const int maxn=1e4+100;
typedef int LL;
typedef pair<LL,LL>P;///first表示编号,second表示权重
vector<P>g[maxn];
LL siz[maxn];
double dp[maxn];
LL n;
void dfs(LL u,LL fa)
{
siz[u]=1;
for(LL i=0;i<g[u].size();i++){
P p=g[u][i];
LL v=p.first;LL w=p.second;
if(v==fa) continue;
dfs(v,u);
siz[u]+=siz[v];
dp[u]+=(dp[v]+(siz[v]*(n-siz[v])*(double)w));
}
}
int main(void)
{
cin.tie(0);std::ios::sync_with_stdio(false);
LL t;cin>>t;
while(t--)
{
cin>>n;
for(LL i=0;i<n+10;i++) g[i].clear(),siz[i]=0,dp[i]=0;
for(LL i=1;i<=n-1;i++)
{
LL u,v,w;cin>>u>>v>>w;
g[u].push_back({v,w});
g[v].push_back({u,w});
}
dfs(0,-1);
LL sum=0;
for(LL i=0;i<n;i++) sum+=i;
printf("%.6f\n",1.0*dp[0]/sum);
}
return 0;
}
然后回到这道题,上面那道题是每个节点都可以相互到达,这题要求只统计块中节点的相互到达。也就是考虑两个块中,(u,v)的边对最终答案有什么贡献。
处理办法是,还是用siz,这时候多加一维,表示每个块编号。
比如siz[A][v]就表示,以v为根的子树中,是A类型的节点有多少个。
然后这样处理出来之后,再像上一题一样去处理。
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<map>
#include<set>
#include<cstdio>
#include<algorithm>
#define debug(a) cout<<#a<<"="<<a<<endl;
using namespace std;
const int maxn=2e5+100;
typedef long long LL;
typedef pair<LL,LL>P;///first表示编号,second表示距离
LL siz[4][maxn],cnt[4];
vector<P>g[maxn];
double ans=0;
void dfs1(LL u,LL fa)
{
for(LL it=0;it<g[u].size();it++)
{
P p=g[u][it];
LL v=p.first;LL w=p.second;
if(v==fa) continue;
dfs1(v,u);
for(LL i=1;i<=3;i++){
siz[i][u]+=siz[i][v];
}
}
}
void dfs2(LL u,LL fa)
{
for(LL it=0;it<g[u].size();it++){
P p=g[u][it];
LL v=p.first;LL w=p.second;
if(v==fa) continue;
dfs2(v,u);
for(LL i=1;i<=3;i++){
for(LL j=1;j<=3;j++)
{
if(i==j) continue;
ans+=1.0*((siz[i][1]-siz[i][v])*(siz[j][v])*1.0*w/cnt[i]/cnt[j]/2.0);
}
}
}
}
int main(void)
{
cin.tie(0);std::ios::sync_with_stdio(false);
LL n;cin>>n;
for(LL i=1;i<n;i++){
LL u,v,w;cin>>u>>v>>w;
g[u].push_back({v,w});
g[v].push_back({u,w});
}
for(LL i=1;i<=3;i++){
cin>>cnt[i];
for(LL j=1;j<=cnt[i];j++)
{
LL x;cin>>x;
siz[i][x]++;
}
}
dfs1(1,-1);
dfs2(1,-1);
printf("%.10f\n",ans);
return 0;
}