Dijkstra最短路径算法

  Dijkstra最短路径算法本质上是广度优先搜索在带权图上的推广:每次从待处理集合中取出当前距离最小的节点进行扩展。
  考虑到算法的remove_smallest操作,使用最小堆提升效率。

class Point(object):
    """Heap entry pairing a node identifier with the value it is ordered by."""

    def __init__(self, index, value):
        # ``index`` identifies the entry; ``value`` is the key compared by ``<``.
        self.index, self.value = index, value

    def __lt__(self, other):
        """Return True when this entry's ``value`` is smaller than ``other``'s.

        Only ``value`` takes part in the comparison; ``index`` is ignored.
        """
        mine, theirs = self.value, other.value
        return mine < theirs


class Heap(object):
    """Binary min-heap of ``Point`` entries ordered by their ``value``.

    Entries are stored in a 1-based array: slot 0 holds an unused placeholder
    so that the children of slot ``i`` are ``2 * i`` and ``2 * i + 1`` and its
    parent is ``i // 2``.
    """

    def __init__(self):
        self.__keys = [Point(None, None)]  # slot 0 is a placeholder, never a live entry
        self.__size = 0

    def add(self, index, value):
        """Insert an entry identified by ``index`` with priority ``value``."""
        self.__keys.append(Point(index, value))
        self.__size += 1
        self.__go_up(self.__size)

    def __go_up(self, index):
        """Sift the entry at slot ``index`` towards the root while it is smaller than its parent."""
        while index > 1:
            parent_index = index // 2
            if not self.__keys[index] < self.__keys[parent_index]:
                break
            self.__swap(index, parent_index)
            index = parent_index

    def __swap(self, index1, index2):
        """Exchange the entries stored in two slots."""
        self.__keys[index1], self.__keys[index2] = self.__keys[index2], self.__keys[index1]

    def remove_smallest(self):
        """Discard the entry with the smallest value.

        Raises:
            IndexError: if the heap is empty. Without this check the size
                would drop below zero and the slot-0 placeholder would be
                popped, corrupting every later operation.
        """
        if self.__size == 0:
            raise IndexError("remove_smallest from an empty heap")
        last = self.__keys.pop()
        self.__size -= 1
        if self.__size > 0:
            # Move the last leaf to the root, then restore the heap order.
            self.__keys[1] = last
            self.__go_down(1)

    def __go_down(self, index):
        """Sift the entry at slot ``index`` towards the leaves while a child is smaller."""
        while True:
            left_child_index = 2 * index
            if left_child_index > self.__size:
                return  # no children: already a leaf
            child_index_swap = left_child_index
            right_child_index = left_child_index + 1
            if (right_child_index <= self.__size
                    and not self.__keys[left_child_index] < self.__keys[right_child_index]):
                child_index_swap = right_child_index
            if not self.__keys[child_index_swap] < self.__keys[index]:
                return  # heap order already holds here
            self.__swap(index, child_index_swap)
            index = child_index_swap

    def get_size(self):
        """Return the number of live entries."""
        return self.__size

    def get_smallest(self):
        """Return ``(index, value)`` of the smallest entry without removing it.

        Raises:
            IndexError: if the heap is empty.
        """
        if self.__size == 0:
            raise IndexError("get_smallest from an empty heap")
        root = self.__keys[1]
        return root.index, root.value

    def update(self, point_index, distance):
        """Set the value of the first entry whose ``index`` equals ``point_index``.

        Does nothing when no live entry matches. The lookup is a linear scan
        over the live slots only; the slot-0 placeholder is never examined,
        so it cannot be matched and sifted into the heap.
        """
        for slot in range(1, self.__size + 1):
            point = self.__keys[slot]
            if point.index == point_index:
                point.value = distance
                # The new value may be smaller or larger than before, so try
                # both directions; at most one of them moves anything.
                self.__go_up(slot)
                self.__go_down(slot)
                break


class Graph(object):
    """Weighted graph over points ``0 .. points_num - 1`` stored as adjacency dicts."""

    def __init__(self, points_num, is_directed):
        """Create a graph with ``points_num`` points and no edges.

        When ``is_directed`` is false every added edge is stored in both
        directions.
        """
        self.__directed = is_directed
        self.__points_num = points_num
        # One ``{neighbour: distance}`` mapping per point.
        self.__adj = []
        for _ in range(points_num):
            self.__adj.append({})

    def add_edge(self, point1, point2, distance):
        """Record an edge from ``point1`` to ``point2`` with the given ``distance``."""
        row = self.__adj[point1]
        row[point2] = distance
        if self.__directed:
            return
        # Undirected graph: mirror the edge.
        mirror_row = self.__adj[point2]
        mirror_row[point1] = distance

    def get_adj(self, point):
        """Return the ``{neighbour: distance}`` mapping of ``point``."""
        neighbours = self.__adj[point]
        return neighbours

    def get_point_nums(self):
        """Return the number of points the graph was created with."""
        return self.__points_num

    def get_distance(self, point1, point2):
        """Return the stored distance of the edge ``point1`` -> ``point2``."""
        row = self.__adj[point1]
        return row[point2]


class Dijkstra(object):
    """Single-source shortest paths over a graph, computed on construction.

    Every point is queued in a min-heap keyed by its best known distance
    from the start; points are processed in order of increasing distance and
    their outgoing edges relaxed. A point removed from the fringe is never
    re-queued, so results are only guaranteed for non-negative edge weights.
    """

    def __init__(self, graph, start):
        """Run the search from ``start``.

        Args:
            graph: object exposing ``get_point_nums()``, ``get_adj(point)``
                and ``get_distance(point1, point2)``.
            start: index of the source point.
        """
        self.__graph = graph
        self.__start = start
        points_num = graph.get_point_nums()
        self.__distance = [float('inf')] * points_num
        self.__distance[start] = 0
        # Predecessor of each point on its best known path; None if none found.
        self.__edge_to = [None] * points_num
        self.__fringe = Heap()
        for point in range(points_num):
            # The start enters with 0, every other point with infinity.
            self.__fringe.add(point, self.__distance[point])
        self.bfs()

    def bfs(self):
        """Drain the fringe, relaxing the outgoing edges of each removed point."""
        while self.__fringe.get_size() > 0:
            point, _ = self.__fringe.get_smallest()
            self.__fringe.remove_smallest()
            for adj_point in self.__graph.get_adj(point):
                candidate = self.__distance[point] + self.__graph.get_distance(point, adj_point)
                if candidate < self.__distance[adj_point]:
                    self.__distance[adj_point] = candidate
                    self.__edge_to[adj_point] = point
                    self.__fringe.update(adj_point, candidate)

    def get_path_to(self, point):
        """Return the points on the shortest path from the start to ``point``.

        The list begins with the start and ends with ``point``. An empty list
        is returned when ``point`` has no recorded predecessor chain back to
        the start, i.e. it was not reached by the search.
        """
        path = [point]
        while point != self.__start:
            point = self.__edge_to[point]
            if point is None:
                # No predecessor recorded: the target is unreachable.
                return []
            path.append(point)
        # Collected target -> start; flip once instead of inserting at the front.
        path.reverse()
        return path

    def get_distance(self, point):
        """Return the shortest distance from the start to ``point`` (infinity if never reached)."""
        return self.__distance[point]

# Demo: a directed graph with 7 points; each tuple is (from, to, distance).
graph = Graph(7, True)
for _source, _target, _weight in (
    (0, 1, 2),
    (0, 2, 1),
    (1, 2, 5),
    (1, 3, 11),
    (1, 4, 3),
    (2, 5, 15),
    (3, 4, 2),
    (4, 2, 1),
    (4, 5, 4),
    (4, 6, 5),
    (6, 3, 1),
    (6, 5, 1),
):
    graph.add_edge(_source, _target, _weight)

# Shortest path from point 0 to point 6 is 0 -> 1 -> 4 -> 6 with length 10.
dijkstra = Dijkstra(graph, 0)
print(dijkstra.get_path_to(6))
print(dijkstra.get_distance(6))

  给每个节点额外加上一个到终点的预估距离(即A*算法的启发值)可提升搜索效率,但若预估距离超过实际剩余距离,则可能找不到最优解。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值