1.问题描述:
给出一个有向无环的连通图,起点为 1,终点为 N,每条边都有一个长度。数据保证从起点出发能够到达图中所有的点,图中所有的点也都能够到达终点。绿豆蛙从起点出发,走向终点。到达每一个顶点时,如果有 K 条离开该点的道路,绿豆蛙可以选择任意一条道路离开该点,并且走向每条路的概率为 1/K。现在绿豆蛙想知道,从起点走到终点所经过的路径总长度的期望是多少?
输入格式
第一行: 两个整数 N,M,代表图中有 N 个点、M 条边。第二行到第 1+M 行: 每行 3 个整数 a,b,c,代表从 a 到 b 有一条长度为 c 的有向边。
输出格式
输出从起点到终点路径总长度的期望值,结果四舍五入保留两位小数。
数据范围
1 ≤ N ≤ 10 ^ 5,
1 ≤ M ≤ 2N
输入样例:
4 4
1 2 1
1 3 2
2 3 3
3 4 4
输出样例:
7.00
来源:https://www.acwing.com/problem/content/219/
2. 思路分析:
对于概率的问题有一个性质:所有的概率问题能使用dp来解决的一个理论基础为:
这个式子经常用来路径长度的期望,一般起点是唯一的,终点不唯一,我们可以从终点往起点开始逆推,其中f(N) = 0,res = f(1),使用记忆化搜索就可以实现从终点一直逆推到起点,我们可以使用一个dout数组来统计每一个点的出度,这样后面在递归的时候可以计算出当前点往下递归的时候1 / k的值。
递推式如下图所示,使用记忆化搜索即可:
3. 代码如下:
python(超时):
from typing import List
import sys
import collections
class Solution:
# 记忆化搜索
def dp(self, u: int, f: List[float], dout: List[int], g: List[List[int]]):
# 当发现之前搜索过了那么直接返回即可
if f[u] >= 0: return f[u]
f[u] = 0
for next in g[u]:
f[u] += (next[1] + self.dp(next[0], f, dout, g)) / dout[u]
return f[u]
def process(self):
n, m = map(int, input().split())
# 使用列表来构建无向图
g = [list() for i in range(n + 10)]
# 计算点的出度
dout = collections.defaultdict(int)
for i in range(m):
a, b, c = map(int, input().split())
g[a].append((b, c))
# 统计节点的出度, 方便后面计算1 / k
dout[a] += 1
f = [-1.00] * (n + 10)
# 保留两位有效数字
return "{:.2f}".format(self.dp(1, f, dout, g))
if __name__ == "__main__":
# 设置递归的最大深度
sys.setrecursionlimit(10000)
print(Solution().process())
c++:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010, M = 200010;
int n, m;
int h[N], e[M], w[M], ne[M], idx;
int dout[N];
double f[N];
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}
double dp(int u)
{
if (f[u] >= 0) return f[u];
f[u] = 0;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
f[u] += (w[i] + dp(j)) / dout[u];
}
return f[u];
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 0; i < m; i ++ )
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c);
dout[a] ++ ;
}
memset(f, -1, sizeof f);
printf("%.2lf\n", dp(1));
return 0;
}