Dijkstra 最短路径算法本质上也是一种广度优先搜索，区别在于它每次总是优先扩展当前已知距离最小的节点。
由于算法需要反复执行 remove_smallest（取出当前距离最小的节点）操作，这里使用最小堆来提升效率。
class Point(object):
    """A heap entry: an ``index`` paired with its priority ``value``.

    Entries are ordered by ``value`` alone (see ``__lt__``); ``index`` is
    carried along so callers can tell which item came out of the heap.
    """

    def __init__(self, index, value):
        self.index = index
        self.value = value

    def __lt__(self, other):
        return self.value < other.value

    def __repr__(self):
        return '{}(index={!r}, value={!r})'.format(
            type(self).__name__, self.index, self.value)


class Heap(object):
    """Binary min-heap of Point entries that supports changing an entry's value.

    The tree is stored 1-based in ``self.__keys``: slot 0 holds an unused
    placeholder so that slot ``i`` has children ``2*i`` and ``2*i + 1`` and
    parent ``i // 2``. ``self.__positions`` maps each entry's index to its
    current slot, so ``update`` can locate an entry without scanning.

    Entry indices are assumed to be unique. If the same index is added twice,
    ``update`` only affects one of those entries.
    """

    def __init__(self):
        self.__keys = [Point(None, None)]  # slot 0: placeholder, never compared
        self.__size = 0
        self.__positions = {}  # entry index -> current slot in self.__keys

    def add(self, index, value):
        """Insert an entry ``index`` with priority ``value``."""
        self.__keys.append(Point(index, value))
        self.__size += 1
        self.__positions[index] = self.__size
        self.__go_up(self.__size)

    def __go_up(self, index):
        """Move the entry at slot ``index`` up while it is smaller than its parent."""
        while index > 1:
            parent_index = index // 2
            if not self.__keys[index] < self.__keys[parent_index]:
                return
            self.__swap(index, parent_index)
            index = parent_index

    def __swap(self, index1, index2):
        """Exchange two slots and keep the index -> slot map in sync."""
        keys = self.__keys
        keys[index1], keys[index2] = keys[index2], keys[index1]
        self.__positions[keys[index1].index] = index1
        self.__positions[keys[index2].index] = index2

    def remove_smallest(self):
        """Remove the entry with the smallest value and return it as (index, value).

        Raises:
            IndexError: if the heap is empty.
        """
        if self.__size == 0:
            raise IndexError('remove_smallest from empty heap')
        smallest = self.__keys[1]
        # Only drop the mapping if it really points at the root slot.
        if self.__positions.get(smallest.index) == 1:
            del self.__positions[smallest.index]
        last = self.__keys.pop()
        self.__size -= 1
        if self.__size > 0:
            # Move the last leaf into the vacated root and sift it down.
            self.__keys[1] = last
            self.__positions[last.index] = 1
            self.__go_down(1)
        return smallest.index, smallest.value

    def __go_down(self, index):
        """Move the entry at slot ``index`` down while a child is smaller."""
        while True:
            left_child_index = 2 * index
            right_child_index = left_child_index + 1
            smallest = index
            if (left_child_index <= self.__size
                    and self.__keys[left_child_index] < self.__keys[smallest]):
                smallest = left_child_index
            if (right_child_index <= self.__size
                    and self.__keys[right_child_index] < self.__keys[smallest]):
                smallest = right_child_index
            if smallest == index:
                return
            self.__swap(index, smallest)
            index = smallest

    def get_size(self):
        """Return the number of entries currently in the heap."""
        return self.__size

    def get_smallest(self):
        """Return (index, value) of the smallest entry without removing it.

        Raises:
            IndexError: if the heap is empty.
        """
        if self.__size == 0:
            raise IndexError('get_smallest from empty heap')
        smallest = self.__keys[1]
        return smallest.index, smallest.value

    def update(self, point_index, distance):
        """Set the value of entry ``point_index`` to ``distance`` and restore heap order.

        Works for both decreases and increases. Does nothing if no entry with
        that index is currently in the heap (e.g. it was already removed).
        """
        position = self.__positions.get(point_index)
        if position is None:
            return
        self.__keys[position].value = distance
        # If the value shrank, the entry sifts up and the go_down call is a
        # no-op. If it grew, go_up is a no-op and go_down sifts it down.
        self.__go_up(position)
        self.__go_down(position)
class Graph(object):
    """Weighted graph over points ``0 .. points_num - 1`` stored as adjacency dicts.

    ``self.__adj[p]`` maps each neighbour of ``p`` to the weight of the edge
    ``p -> neighbour``. For an undirected graph every edge is stored in both
    directions.
    """

    def __init__(self, points_num, is_directed):
        self.__points_num = points_num
        self.__adj = [{} for _ in range(points_num)]
        self.__directed = is_directed

    def __check_point(self, point):
        """Raise IndexError unless ``point`` is a valid point id."""
        if not 0 <= point < self.__points_num:
            raise IndexError('point {!r} out of range [0, {})'.format(
                point, self.__points_num))

    def add_edge(self, point1, point2, distance):
        """Add (or overwrite) the edge ``point1 -> point2`` with weight ``distance``.

        For an undirected graph the reverse edge is added as well. Both
        endpoints are validated before anything is stored, so a bad call
        leaves the graph unchanged. Negative ids are rejected rather than
        wrapping around to the end of the adjacency list.

        Raises:
            IndexError: if either endpoint is outside ``[0, points_num)``.
        """
        self.__check_point(point1)
        self.__check_point(point2)
        self.__adj[point1][point2] = distance
        if not self.__directed:
            self.__adj[point2][point1] = distance

    def get_adj(self, point):
        """Return the neighbour -> weight dict of ``point`` (the internal dict, not a copy)."""
        return self.__adj[point]

    def get_point_nums(self):
        """Return the number of points in the graph."""
        return self.__points_num

    def get_distance(self, point1, point2):
        """Return the weight of edge ``point1 -> point2``; raises KeyError if there is none."""
        return self.__adj[point1][point2]
class Dijkstra(object):
    """Single-source shortest paths (Dijkstra's algorithm) from ``start``.

    The whole search runs in the constructor. Afterwards ``get_distance`` and
    ``get_path_to`` answer queries for any point. Edge weights must be
    non-negative, because Dijkstra's greedy settling order is only correct then.

    ``graph`` is used through ``get_point_nums()``, ``get_adj(point)``
    (iterated for neighbour ids) and ``get_distance(point1, point2)``.
    """

    def __init__(self, graph, start):
        """Run the search from ``start``.

        Raises:
            IndexError: if ``start`` is not in ``[0, graph.get_point_nums())``.
            ValueError: if a negative edge weight is encountered.
        """
        points_num = graph.get_point_nums()
        if not 0 <= start < points_num:
            raise IndexError('start {!r} out of range [0, {})'.format(
                start, points_num))
        self.__graph = graph
        self.__start = start
        self.__distance = [float('inf')] * points_num
        self.__distance[start] = 0
        # edge_to[p] is the previous point on the best known path to p.
        self.__edge_to = [None] * points_num
        # Every point starts in the fringe, keyed by its tentative distance.
        self.__fringe = Heap()
        for point in range(points_num):
            self.__fringe.add(point, self.__distance[point])
        self.bfs()

    def __check_point(self, point):
        """Raise IndexError unless ``point`` is a valid point id."""
        if not 0 <= point < len(self.__distance):
            raise IndexError('point {!r} out of range [0, {})'.format(
                point, len(self.__distance)))

    def bfs(self):
        """Settle points in increasing distance order, relaxing their outgoing edges.

        Raises:
            ValueError: if a negative edge weight is encountered.
        """
        while self.__fringe.get_size() > 0:
            point, _ = self.__fringe.get_smallest()
            self.__fringe.remove_smallest()
            for adj_point in self.__graph.get_adj(point):
                distance = self.__graph.get_distance(point, adj_point)
                if distance < 0:
                    raise ValueError(
                        'negative edge weight {!r} on {!r} -> {!r}'.format(
                            distance, point, adj_point))
                new_distance = self.__distance[point] + distance
                if new_distance < self.__distance[adj_point]:
                    self.__distance[adj_point] = new_distance
                    self.__edge_to[adj_point] = point
                    self.__fringe.update(adj_point, new_distance)

    def get_path_to(self, point):
        """Return the shortest path from ``start`` to ``point`` as a list of points.

        The list begins with ``start`` and ends with ``point``. It is
        ``[start]`` when ``point == start`` and ``[]`` when ``point`` is
        unreachable from ``start``.

        Raises:
            IndexError: if ``point`` is not a valid point id.
        """
        self.__check_point(point)
        if self.__distance[point] == float('inf'):
            return []
        path = [point]
        while point != self.__start:
            point = self.__edge_to[point]
            path.append(point)
        path.reverse()
        return path

    def get_distance(self, point):
        """Return the shortest distance from ``start`` to ``point`` (inf if unreachable).

        Raises:
            IndexError: if ``point`` is not a valid point id.
        """
        self.__check_point(point)
        return self.__distance[point]
# Demo: shortest path from point 0 to point 6 in a 7-point directed graph.
graph = Graph(7, True)
edges = [
    (0, 1, 2),
    (0, 2, 1),
    (1, 2, 5),
    (1, 3, 11),
    (1, 4, 3),
    (2, 5, 15),
    (3, 4, 2),
    (4, 2, 1),
    (4, 5, 4),
    (4, 6, 5),
    (6, 3, 1),
    (6, 5, 1),
]
for edge in edges:  # (from_point, to_point, weight)
    graph.add_edge(*edge)

dijkstra = Dijkstra(graph, 0)
print(dijkstra.get_path_to(6))  # [0, 1, 4, 6]
print(dijkstra.get_distance(6))  # 10
为节点添加到终点的预估距离（启发式函数，即 A* 算法的思路）可以提升搜索效率；但若预估距离过大（高估了实际剩余距离），则可能找不到最优解。