You are in a city that consists of n intersections numbered from 0 to n - 1 with bi-directional roads between some intersections. The inputs are generated such that you can reach any intersection from any other intersection and that there is at most one road between any two intersections.
You are given an integer n and a 2D integer array roads where roads[i] = [u_i, v_i, time_i] means that there is a road between intersections u_i and v_i that takes time_i minutes to travel. You want to know in how many ways you can travel from intersection 0 to intersection n - 1 in the shortest amount of time.
Return the number of ways you can arrive at your destination in the shortest amount of time. Since the answer may be large, return it modulo 10^9 + 7.
Example 1:
Input: n = 7, roads = [[0,6,7],[0,1,2],[1,2,3],[1,3,3],[6,3,3],[3,5,1],[6,5,1],[2,5,1],[0,4,5],[4,6,2]]
Output: 4
Explanation: The shortest amount of time it takes to go from intersection 0 to intersection 6 is 7 minutes.
The four ways to get there in 7 minutes are:
- 0 ➝ 6
- 0 ➝ 4 ➝ 6
- 0 ➝ 1 ➝ 2 ➝ 5 ➝ 6
- 0 ➝ 1 ➝ 3 ➝ 5 ➝ 6
Example 2:
Input: n = 2, roads = [[1,0,10]]
Output: 1
Explanation: There is only one way to go from intersection 0 to intersection 1, and it takes 10 minutes.
Constraints:
1 <= n <= 200
n - 1 <= roads.length <= n * (n - 1) / 2
roads[i].length == 3
0 <= u_i, v_i <= n - 1
1 <= time_i <= 10^9
u_i != v_i
There is at most one road connecting any two intersections.
You can reach any intersection from any other intersection.
Dijkstra with duplicate shortest paths
Simple correct code for Dijkstra with duplicate shortest paths
class Solution:
    def countPaths(self, n: int, roads: List[List[int]]) -> int:
        """Return how many shortest paths lead from 0 to n-1, modulo 1000000007.

        Dijkstra's algorithm in which ``ways[v]`` counts the shortest paths
        found so far into ``v``. A strictly shorter distance replaces that
        count, and an equally short one adds to it.

        Args:
            n: number of intersections, labelled 0 .. n-1.
            roads: ``[u, v, time]`` triples describing undirected roads.

        Returns:
            Number of shortest 0 -> n-1 paths, modulo 1000000007.
        """
        MOD = 1000000007
        graph = defaultdict(list)
        for u, v, t in roads:
            graph[u].append((v, t))
            graph[v].append((u, t))

        best = defaultdict(lambda: float('inf'))
        best[0] = 0
        ways = [0] * n
        ways[0] = 1
        heap = [(0, 0)]
        while heap:
            dist, node = heapq.heappop(heap)
            if dist > best[node]:
                continue  # stale entry: a shorter distance was found later
            # Weights are >= 1, so every shortest-path predecessor of `node`
            # was popped earlier and ways[node] is already complete here.
            for nb, cost in graph[node]:
                cand = dist + cost
                if cand < best[nb]:
                    best[nb] = cand
                    ways[nb] = ways[node]
                    heapq.heappush(heap, (cand, nb))
                elif cand == best[nb]:
                    ways[nb] = (ways[nb] + ways[node]) % MOD
        return ways[n - 1]
Correct code for Dijkstra with duplicate shortest paths
Uses a plain dict instead of a defaultdict so it can be compared line by line with the wrong code below.
import heapq
from typing import List
from collections import defaultdict
class Solution:
    def countPaths(self, n: int, roads: List[List[int]]) -> int:
        """Count shortest 0 -> n-1 paths with Dijkstra, modulo 1000000007.

        ``shortest`` is a plain dict holding each node's tentative distance.
        The distance is written as soon as the node is pushed onto the heap,
        so a later path of equal length reaches the ``==`` branch and adds its
        count.

        Args:
            n: number of intersections, labelled 0 .. n-1.
            roads: ``[u, v, time]`` triples describing undirected roads.

        Returns:
            Number of shortest 0 -> n-1 paths, modulo 1000000007.
        """
        MOD = 1000000007
        adjacency = defaultdict(list)
        for a, b, w in roads:
            adjacency[a].append((b, w))
            adjacency[b].append((a, w))

        shortest = {0: 0}
        count = [0] * n
        count[0] = 1
        frontier = [(0, 0)]
        while frontier:
            d, u = heapq.heappop(frontier)
            # Every node on the heap already has an entry in `shortest`.
            if d > shortest[u]:
                continue
            for v, w in adjacency[u]:
                nd = d + w
                known = shortest.get(v)
                if known is None or nd < known:
                    shortest[v] = nd
                    heapq.heappush(frontier, (nd, v))
                    count[v] = count[u]
                elif nd == known:
                    count[v] = (count[v] + count[u]) % MOD
        return count[n - 1]
# Example 1 from the problem statement; the expected output is 4.
s = Solution()
print(s.countPaths(
    n=7,
    roads=[[0, 6, 7], [0, 1, 2], [1, 2, 3], [1, 3, 3], [6, 3, 3],
           [3, 5, 1], [6, 5, 1], [2, 5, 1], [0, 4, 5], [4, 6, 2]],
))
# Larger 12-node case, kept for manual debugging:
# print(s.countPaths(n = 12, roads = [[1,0,2348],[2,1,2852],[2,0,5200],[3,1,12480],[2,3,9628],[4,3,7367],[4,0,22195],[5,4,5668],[1,5,25515],[0,5,27863],[6,5,836],[6,0,28699],[2,6,23499],[6,3,13871],[1,6,26351],[5,7,6229],[2,7,28892],[1,7,31744],[3,7,19264],[6,7,5393],[2,8,31998],[8,7,3106],[3,8,22370],[8,4,15003],[8,6,8499],[8,5,9335],[8,9,5258],[9,2,37256],[3,9,27628],[7,9,8364],[1,9,40108],[9,5,14593],[2,10,45922],[5,10,23259],[9,10,8666],[10,0,51122],[10,3,36294],[10,4,28927],[11,4,25190],[11,9,4929],[11,8,10187],[11,6,18686],[2,11,42185],[11,3,32557],[1,11,45037]]))
Wrong code for Dijkstra with duplicate shortest paths
import heapq
from typing import List
from collections import defaultdict
class Solution:
    """Deliberately WRONG Dijkstra path counter, kept to contrast with the correct one.

    Compared with the correct version, a neighbour's tentative distance is not
    stored when it is pushed. ``min_dis_dic`` is only written when a node is
    popped. Since every road time is >= 1:

    * A neighbour that has not been popped yet is absent from ``min_dis_dic``.
      Every relaxation therefore takes the first branch: it pushes another
      heap entry and *overwrites* ``dp[nxt_node]``, even for longer or
      equal-length paths.
    * Once a neighbour has been popped, its recorded distance is
      <= ``cur_dis`` < ``cur_dis + cost``. Neither branch can fire, and in
      particular the ``==`` branch that sums counts is never reached.

    dp values are therefore only ever copied from ``dp[0] == 1`` and never
    added, so every entry is 0 or 1. Example 1 prints 1 instead of 4.
    """
    def countPaths(self, n: int, roads: List[List[int]]) -> int:
        # Undirected adjacency list: node -> [(neighbour, travel time), ...].
        nxt_dic = defaultdict(list)
        MOD = 1000000007
        for road in roads:
            nxt_dic[road[0]].append((road[1],road[2]))
            nxt_dic[road[1]].append((road[0],road[2]))
        pq = [(0, 0)]
        dp = [0]*n
        dp[0] = 1
        min_dis_dic = {0:0}
        while (pq):
            cur_dis, cur_node = heapq.heappop(pq)
            # `<=` lets repeated heap entries at the same distance be processed again.
            if cur_node not in min_dis_dic or cur_dis <= min_dis_dic[cur_node]:
                # BUG: the distance is recorded only on pop, not when the node is pushed.
                min_dis_dic[cur_node] = cur_dis
                for nxt_node,cost in nxt_dic[cur_node]:
                    # Always true for a neighbour that has not been popped yet, so
                    # its count is overwritten instead of being compared and summed.
                    if nxt_node not in min_dis_dic or cur_dis+cost < min_dis_dic[nxt_node]:
                        heapq.heappush(pq, (cur_dis+cost,nxt_node))
                        dp[nxt_node] = dp[cur_node]
                    # Unreachable with weights >= 1: a popped neighbour's distance is
                    # <= cur_dis, which is strictly less than cur_dis+cost.
                    elif cur_dis+cost == min_dis_dic[nxt_node]:
                        dp[nxt_node] = (dp[nxt_node] + dp[cur_node]) % MOD
        return dp[n-1]
# Example 1 again, run through the wrong implementation: it prints 1, while
# the correct answer is 4, because its dp counts are copied and never summed.
s = Solution()
print(s.countPaths(
    n=7,
    roads=[[0, 6, 7], [0, 1, 2], [1, 2, 3], [1, 3, 3], [6, 3, 3],
           [3, 5, 1], [6, 5, 1], [2, 5, 1], [0, 4, 5], [4, 6, 2]],
))
# Larger 12-node case, kept for manual debugging:
# print(s.countPaths(n = 12, roads = [[1,0,2348],[2,1,2852],[2,0,5200],[3,1,12480],[2,3,9628],[4,3,7367],[4,0,22195],[5,4,5668],[1,5,25515],[0,5,27863],[6,5,836],[6,0,28699],[2,6,23499],[6,3,13871],[1,6,26351],[5,7,6229],[2,7,28892],[1,7,31744],[3,7,19264],[6,7,5393],[2,8,31998],[8,7,3106],[3,8,22370],[8,4,15003],[8,6,8499],[8,5,9335],[8,9,5258],[9,2,37256],[3,9,27628],[7,9,8364],[1,9,40108],[9,5,14593],[2,10,45922],[5,10,23259],[9,10,8666],[10,0,51122],[10,3,36294],[10,4,28927],[11,4,25190],[11,9,4929],[11,8,10187],[11,6,18686],[2,11,42185],[11,3,32557],[1,11,45037]]))
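Diff between the correct and the wrong code
错误原因在于:错误代码只在节点出堆时才写入 min_dis_dic[cur_node],入堆(松弛)时没有记录暂定距离。(1) 邻居出堆之前不在 min_dis_dic 中,每次松弛都会进入 `<` 分支,重复入堆并用 dp[cur_node] 覆盖 dp[nxt_node],等长路径的计数被覆盖而不是累加;(2) 邻居出堆之后,其记录的距离 ≤ cur_dis < cur_dis+cost,`==` 分支永远不会触发。因此 dp 只会被复制、从不相加,例1输出 1 而不是 4。正确代码在松弛时就写入 min_dis_dic[nxt_node],等长路径才能走 `==` 分支累加计数。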
BFS that keeps relaxing all edges = Dijkstra
用 BFS 反复松弛所有边(某节点距离变短就把它放入下一层),最终可以得到和 Dijkstra 一样的最短距离。但有两个坑:坑一,BFS 的中间状态还不是最短路生成树,如果边 BFS 边 DP 就会出错;坑二,BFS 得到的是 prev_dict(每个节点在最短路上的前驱集合),直接在 prev_dict 上按层做 DP 也会重复计数,需要先得到拓扑排序的结果再 DP。下面是一份 DP 正确的代码。
Correct code for BFS
非常麻烦
class Solution:
    def countPaths(self, n: int, roads: List[List[int]]) -> int:
        """Count shortest 0 -> n-1 paths (mod 1000000007) without a priority queue.

        Phase 1 relaxes edges frontier by frontier and re-queues any node
        whose distance improves (SPFA-style) until nothing changes. It records
        for each node the set of predecessors that reach it along a shortest
        path.

        Phase 2 counts paths over that predecessor DAG in topological order.
        Counting cannot be done during phase 1, because intermediate distances
        (and therefore the predecessor sets) are not final yet.

        Args:
            n: number of intersections, labelled 0 .. n-1.
            roads: ``[u, v, time]`` triples describing undirected roads.

        Returns:
            Number of shortest 0 -> n-1 paths, modulo 1000000007.
        """
        MOD = 1000000007

        def count_from_pre(prev_dic, n):
            """Count 0 -> n-1 paths in the DAG given by predecessor sets.

            Counts must be pushed in topological order. Pushing them layer by
            layer turns the 4 paths 0->1->2->3, 0->1->3, 0->2->3 and 0->3
            into 6, because node 2 lies in two layers and pushes its count
            twice.
            """
            succ = [[] for _ in range(n)]
            in_degree = [0] * n
            for node, prev_set in prev_dic.items():
                in_degree[node] = len(prev_set)
                for prev_node in prev_set:
                    succ[prev_node].append(node)
            dp = [0] * n
            dp[0] = 1
            # Kahn's algorithm: a node is popped only after all of its
            # predecessors, so dp[node] is final when it is propagated.
            ready = [node for node in range(n) if in_degree[node] == 0]
            while ready:
                node = ready.pop()
                for nxt_node in succ[node]:
                    dp[nxt_node] = (dp[nxt_node] + dp[node]) % MOD
                    in_degree[nxt_node] -= 1
                    if in_degree[nxt_node] == 0:
                        ready.append(nxt_node)
            return dp[n - 1]

        nxt_dic = defaultdict(list)
        for u, v, t in roads:
            nxt_dic[u].append((v, t))
            nxt_dic[v].append((u, t))

        dis = [float('inf')] * n
        dis[0] = 0
        prev_dic = defaultdict(set)
        layer = [0]
        while layer:
            next_layer, queued = [], set()
            for cur_node in layer:
                for nxt_node, cost in nxt_dic[cur_node]:
                    cand = dis[cur_node] + cost
                    if cand < dis[nxt_node]:
                        # Not necessarily final: a later frontier may improve it,
                        # which resets the predecessor set again.
                        dis[nxt_node] = cand
                        prev_dic[nxt_node] = {cur_node}
                        # Enqueue once per frontier; it is processed with its latest distance.
                        if nxt_node not in queued:
                            queued.add(nxt_node)
                            next_layer.append(nxt_node)
                    elif cand == dis[nxt_node]:
                        prev_dic[nxt_node].add(cur_node)
            layer = next_layer
        # Distances have converged, so prev_dic is now the shortest-path DAG.
        return count_from_pre(prev_dic, n)
为什么直接用prev_dict做DP是错的
from collections import defaultdict
def count_from_pre_wrong(prev_dic, n):
    """Deliberately WRONG path counter over a predecessor DAG, kept for comparison.

    ``prev_dic`` maps a node to the set of its predecessors. Every node
    reached while walking back from n-1 is indexed directly, so a plain dict
    must contain all of those keys.

    Phase 1 walks backwards from n-1 one layer at a time and inverts
    ``prev_dic`` into successor sets (``nxt_dic``). Phase 2 pushes path
    counts forward from node 0, again one layer at a time.

    Why it is wrong: a node that is reachable by paths with different numbers
    of edges appears in several layers, and it pushes its count once per
    layer. For the DAG with paths 0->1->2->3, 0->1->3, 0->2->3 and 0->3
    (4 paths) it returns 6.

    Side effect: prints ``nxt_dic`` and every forward layer for debugging.
    """
    # Error cause: the 4 paths 0->1->2->3, 0->1->3, 0->2->3 and 0->3 are
    # counted as 6 when counts are pushed layer by layer.
    # Fix: propagate the counts in topological order instead.
    nxt_dic = defaultdict(set)
    lc,ln = 0,1
    # Phase 1: backward layers starting at n-1, building successor sets.
    layers = [set({n-1}),set()]
    dp = [0] * n
    dp[0] = 1
    while (layers[lc]):
        for cur_node in layers[lc]:
            for prev_node in prev_dic[cur_node]:
                nxt_dic[prev_node].add(cur_node)
                layers[ln].add(prev_node)
        layers[lc].clear()
        lc, ln = ln, lc
    print('nxt_dic={}'.format(nxt_dic))
    # Phase 2: forward layers from node 0. A node re-added in a later layer
    # pushes its count again, which is the source of the over-counting.
    layers = [set({0}),set()]
    vis=set({0})  # NOTE(review): never read
    lc, ln = 0, 1
    while (layers[lc]):
        print('layers[lc]={}'.format(layers[lc]))
        for cur_node in layers[lc]:
            for nxt_node in nxt_dic[cur_node]:
                dp[nxt_node] = dp[nxt_node] + dp[cur_node]
                layers[ln].add(nxt_node)
        layers[lc].clear()
        lc, ln = ln, lc
    return dp[n-1]
def count_from_pre(prev_dic, n):
# 错误原因:例如存在0->1->2->3、0->1->3、0->2->3、0->3,这样3条路径,按层遍历结果变成了6
# 纠正方法:需要拓扑排序一下
dp = [0] * n
dp[0] = 1
reverse_seq = []
out_degree = [0]*n
for node, prev_set in prev_dic.items():
for prev_node in prev_set:
out_degree[prev_node] += 1
out_degree_0 = [n-1]
while (out_degree_0):
n0 = out_degree_0.pop()
reverse_seq.append(n0)
for prev_node in prev_dic[n0]:
out_degree[prev_node] -= 1
if out_degree[prev_node] == 0:
out_degree_0.append(prev_node)
print(reverse_seq)
for node in reverse_seq[::-1]:
for nxt_node, prev_set in prev_dic.items():
if node in prev_set:
dp[nxt_node] += dp[node]
return dp[n-1]
# Predecessor sets for the DAG with edges 0->1, 0->2, 0->3, 1->2, 1->3, 2->3.
# It has exactly 4 paths from 0 to 3. A fresh dict is built for each call.
prev_dic = {3: {0, 1, 2}, 2: {0, 1}, 1: {0}, 0: set()}
print(count_from_pre_wrong(prev_dic=prev_dic, n=4))  # layer-by-layer push: 6
prev_dic = {3: {0, 1, 2}, 2: {0, 1}, 1: {0}, 0: set()}
print(count_from_pre(prev_dic=prev_dic, n=4))  # topological order: 4