这题还是有点意思的,算是个杂交题。先用dijkstra求出题目要求的单源最短路径,再用记忆化搜索优化dfs。这里dfs是可以用记忆化搜索优化的,将dfs的每一个状态含有的解计算好后放入dp数组,这样下次再遇到这个状态就可以直接返回。
from heapq import *
from collections import defaultdict
class Solution(object):
def countRestrictedPaths(self, n, edges):
"""
:type n: int
:type edges: List[List[int]]
:rtype: int
"""
class Node:
def __init__(self, node, dist):
self.node = node
self.dist = dist
def __lt__(self, n):
return self.dist < n.dist
MAX = 1 << 31
MOD = int(1e9 + 7)
# 处理边
node2edges = defaultdict(list)
for edge in edges:
node2edges[edge[0]].append((edge[1], edge[2]))
node2edges[edge[1]].append((edge[0], edge[2]))
# 求单源最短路
dis = [MAX for _ in range(n+1)]
dis[n] = 0
o = set([i for i in range(1, n)])
min_node = n
minheap = [Node(i, dis[i]) for i in range(1, n)]
heapify(minheap)
while len(o) > 0:
# 更新选中节点的邻接节点
for node, dist in node2edges[min_node]:
dist += dis[min_node]
if dist < dis[node]:
dis[node] = dist
heappush(minheap, Node(node, dist))
# 选择距离最小节点
while True:
n0 = heappop(minheap)
min_node, dist = n0.node, n0.dist
if min_node in o: break
o.remove(min_node)
dp = [0 for _ in range(n+1)]
def dfs(node):
if node == n:
return 1
if dp[node] > 0: return dp[node]
ret = 0
edges = node2edges[node]
for new_node, dist in edges:
if dis[node] > dis[new_node]:
ret += dfs(new_node)%MOD
dp[node] = ret%MOD
return ret%MOD
ret = dfs(1)%MOD
return ret
经过新的题发现,记忆化搜索只需一行代码。真牛啊。
@cache
def dfs(node):
if node == n:
return 1
ret = 0
edges = node2edges[node]
for new_node, dist in edges:
if dis[node] > dis[new_node]:
ret += dfs(new_node)%MOD
return ret%MOD