这里给出两种实现方法:动态规划 和 dijkstra单源最短路径算法。
class FindCheapestPrice:
"""
787. K 站中转内最便宜的航班
https://leetcode.cn/problems/cheapest-flights-within-k-stops/
"""
def solution(self, n: int, flights: List[List[int]], src: int, dst: int, k: int) -> int:
# 从起点 src 出发,k 步之内(⼀步就是⼀条边)到达节点 s 的最⼩路径权重为 dp(s, k)。
# 将中转站个数转化成边的条数
k += 1
self.src = src
self.dst = dst
self.indegree = dict()
self.memo = [[-666 for _ in range(k+1)] for _ in range(n)]
for f in flights:
start = f[0]
end = f[1]
weight = f[2]
if end not in self.indegree:
self.indegree[end] = []
self.indegree.get(end).append([start, weight])
return self.dp(dst, k)
def dp(self, s, k):
"""
定义:从 src 出发,k 步之内到达 s 的最短路径权重
:param s:
:param k:
:return:
"""
if s == self.src:
return 0
if k == 0:
return -1
if self.memo[s][k] != -666:
return self.memo[s][k]
res = float('inf')
if s not in self.indegree:
return -1
for v in self.indegree[s]:
start = v[0]
price = v[1]
# 从 src 到达相邻的⼊度节点所需的最短路径权重
subproblem = self.dp(start, k-1)
# 跳过⽆解的情况
if subproblem != -1:
res = min(res, subproblem+price)
res = -1 if res == float('inf') else res
self.memo[s][k] = res
return res
def solution2(self, n: int, flights: List[List[int]], src: int, dst: int, k: int) -> int:
"""
应用dijkstra算法,和标准的dijkstra算法区别是,这里加入了最多中转次数k的限制,
所以加入队列的条件会有变化。
:param n:
:param flights:
:param src:
:param dst:
:param k:
:return:
"""
# 将中转站个数转化成边的条数
k += 1
self.src = src
self.dst = dst
self.graph = [[] for _ in range(n)]
for f in flights:
start = f[0]
end = f[1]
weight = f[2]
self.graph[start].append([end, weight])
return self.dijkstra(src, dst, k, self.graph)
def dijkstra(self, src, dst, k, graph: List[List[int]]):
"""
输⼊⼀个起点 src,计算从 src 到其他节点的最短距离
:param start:
:param graph:
:return:
"""
import heapq
v = len(graph)
# dp table,distTo[i]可理解为节点s到节点i的最短路劲,后续要不停地更新该表
# 定义:从起点 src 到达节点 i 的最短路径权重为 distTo[i]
# 定义:从起点 src 到达节点 i 的最⼩权重路径⾄少要经过 nodeNumTo[i] 个节点
distTo = [float('inf')] * v
nodeNumTo = [float('inf')] * v
# base case
distTo[src] = 0
nodeNumTo[src] = 0
min_heap = []
# 从起点s开始BFS
heapq.heappush(min_heap, self.State(src, 0, 0))
while min_heap:
curState = heapq.heappop(min_heap)
curNodeID = curState.id
curCostFromStart = curState.costFromStart
curNodeNumFromStart = curState.nodeNumFromStart
# 找到最短路径
if curNodeID == dst:
return curCostFromStart
# 中转次数耗尽
if curNodeNumFromStart == k:
continue
# 遍历curNodeID的相邻节点
for neighbor in graph[curNodeID]:
nextNodeID = neighbor[0]
costToNextNode = curCostFromStart + neighbor[1]
nextNodeNumFromSrc = curNodeNumFromStart + 1
# 剪枝,如果中转次数更多,花费还更⼤,那必然不会是最短路径
if costToNextNode > distTo[nextNodeID] and \
nextNodeNumFromSrc > nodeNumTo[nextNodeID]:
continue
if costToNextNode < distTo[nextNodeID]:
# 更新dp table
distTo[nextNodeID] = costToNextNode
nodeNumTo[nextNodeID] = nextNodeNumFromSrc
# 将该邻居节点加入优先级队列
heapq.heappush(min_heap, self.State(nextNodeID, costToNextNode, nextNodeNumFromSrc))
return -1
class State:
def __init__(self, id, costFromStart, nodeNumFromStart):
"""
:param id: 图节点的 id
:param costFromStart: 从 src 节点到当前节点的花费
:param nodeNumFromStart: 从 src 节点到当前节点经过的节点个数
"""
self.id = id
self.costFromStart = costFromStart
self.nodeNumFromStart = nodeNumFromStart
def __lt__(self, other):
if self.costFromStart < other.costFromStart:
return True
return False