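Network Algorithms: Dijkstra's Algorithm with a Heap and Circular Buckets

1 Implementation Overview

This experiment implements Dijkstra's algorithm in three versions: a plain version, one backed by a heap, and one backed by circular buckets. The heap and bucket structures are defined by hand. Graphs are represented with the networkx library, which makes it easy to attach information to nodes and edges. For testing, nodes, edges, and weights are generated randomly, and the running times of the three versions are measured and compared.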

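The heap is stored as a list of the form [{'key': weight, 'info': []}, ...]. The key is the value used to order entries, and info stores the node together with its adjacent-edge information. When a value is updated, the node's position in the array is looked up in __index, and the entry is then moved up or down as needed instead of rebuilding the heap.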
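The following is a minimal sketch of the entry layout and the position-map idea described above, not the class from the source code section. The names `pos`, `sift_up`, `insert`, and `decrease_key` are illustrative assumptions, and the sketch uses 0-based positions.

```python
# Illustrative only: entries use the {'key': ..., 'info': [node, neighbor]} layout,
# and `pos` plays the role of the __index dictionary.
heap = []   # list of entries, arranged as a binary min-heap
pos = {}    # node name -> current position in `heap`


def sift_up(i: int) -> None:
    """Move the entry at position i up until its parent is not larger."""
    while i > 0:
        parent = (i - 1) // 2
        if heap[parent]['key'] <= heap[i]['key']:
            break
        heap[parent], heap[i] = heap[i], heap[parent]
        # Keep the position map in sync with the swap.
        pos[heap[parent]['info'][0]] = parent
        pos[heap[i]['info'][0]] = i
        i = parent


def insert(weight, node, neighbor) -> None:
    heap.append({'key': weight, 'info': [node, neighbor]})
    pos[node] = len(heap) - 1
    sift_up(pos[node])


def decrease_key(weight, node, neighbor) -> None:
    """Lower a node's key in place and restore heap order without rebuilding."""
    i = pos[node]              # dictionary lookup instead of scanning the list
    heap[i]['key'] = weight
    heap[i]['info'][1] = neighbor
    sift_up(i)


insert(5, 'a', None)
insert(7, 'b', None)
insert(9, 'c', None)
decrease_key(1, 'c', 'a')
print(heap[0], pos)
```

The point of the position map is that an update costs one dictionary lookup plus a walk along a single root-to-leaf path, rather than a linear search or a full rebuild.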

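The buckets use the same [{'key': weight, 'info': []}, ...] entry format: key determines which bucket an entry goes into, and info stores the node together with its adjacent-edge information. __index records the bucket each node currently sits in.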
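A small sketch of the slot arithmetic behind circular buckets may help here. It assumes non-negative integer edge weights with maximum C, in which case the tentative distances waiting in the queue all lie in [d, d + C], where d is the distance of the node extracted most recently. The values of `C` and `d` below are made up for illustration.

```python
C = 4                      # assumed maximum edge weight, for illustration
num_buckets = C + 1


def slot(key: int) -> int:
    """Bucket that a tentative distance `key` is stored in."""
    return key % num_buckets


d = 7   # hypothetical distance of the node extracted most recently
live_keys = range(d, d + C + 1)          # every key that can be queued right now
slots = [slot(k) for k in live_keys]
print(slots)   # [2, 3, 4, 0, 1]
```

Each of the C + 1 possible keys lands in its own slot, so two different distances never share a bucket and scanning forward from the pointer visits keys in increasing order.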

2 Demo

Input:

s a 2;s e 4;a e 1;a b 4;a d 2;b c 2;d b 3;d c 2;e d 3

The corresponding graph is shown below:

[Figure: the input graph]

Output:

----------Dij2----------

s->e:weight:3 s->a->e

s->a:weight:2 s->a

s->b:weight:6 s->a->b

s->d:weight:4 s->a->d

s->c:weight:6 s->a->d->c

----------Dij3----------

s->e:weight:3 s->a->e

s->a:weight:2 s->a

s->b:weight:6 s->a->b

s->d:weight:4 s->a->d

s->c:weight:6 s->a->d->c

----------Dij4----------

s->e:weight:3 s->a->e

s->a:weight:2 s->a

s->b:weight:6 s->a->b

s->d:weight:4 s->a->d

s->c:weight:6 s->a->d->c
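As an independent check of the distances printed above, here is a short reference run using the standard-library `heapq` module and a plain adjacency dict. The function name `reference_dijkstra` is an assumption for this sketch and is not part of the source code.

```python
import heapq

edges = "s a 2;s e 4;a e 1;a b 4;a d 2;b c 2;d b 3;d c 2;e d 3"

# Build a plain adjacency dict: node -> list of (neighbor, weight).
adj = {}
for item in edges.split(';'):
    u, v, w = item.split()
    adj.setdefault(u, []).append((v, int(w)))
    adj.setdefault(v, [])


def reference_dijkstra(adj, start):
    dist = {start: 0}
    pq = [(0, start)]
    done = set()
    while pq:
        d, u = heapq.heappop(pq)
        if u in done:
            continue  # stale entry left behind by an earlier, larger distance
        done.add(u)
        for v, w in adj[u]:
            if d + w < dist.get(v, float('inf')):
                dist[v] = d + w
                heapq.heappush(pq, (d + w, v))
    return dist


ref_dist = reference_dijkstra(adj, 's')
print(ref_dist)
```

The resulting distances (a: 2, e: 3, d: 4, b: 6, c: 6) agree with the weights reported by Dij2, Dij3, and Dij4.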

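3 Heap and Bucket Class Definitions

Class name: Heap

Attributes:

  • self.heap = list(): the main heap storage

  • self.__index = dict(): the position of each node in the heap

Methods:

Method: purpose
def __init__(self): initialize the class
def __str__(self): return a description of the heap
def Up(self, i: int): sift the given node up
def Down(self, i: int): sift the given node down
def Insert(self, weight: int, node: str, neighbor: str): insert a new node
def GetIndex(self, find: str): look up a node's position by name
def GetMin(self): remove and return the minimum entry of the heap
def Change(self, weight: int, node: str, neighbor: str): change the value of the given node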

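Class name: CBucket

Attributes:

  • self.bucket = [[] for _ in range(num)]: the circular buckets themselves

  • self.__index = dict(): the bucket each node currently sits in

  • self.__point = 0: the pointer position

  • self.__C = num - 1: the maximum edge weight the buckets are built for

Methods:

Method: purpose
def __init__(self, num: int): initialize the buckets
def __str__(self): return a description of the buckets
def Insert(self, data: dict): insert an entry, using its key as the weight that selects the bucket
def Change(self, node: str, new_weight, new_neighbor): update an entry by finding its old bucket, deleting it, and inserting it again
def IsNull(self) -> bool: check whether the buckets are empty
def GetWeight(self, node: str): return the weight stored for the queried node
def FindMin(self) -> dict: find the first non-empty bucket, starting from the pointer, and pop an entry from it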
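To show how these operations fit together, here is a compact, self-contained sketch of a circular-bucket shortest-path search. It is not the CBucket class: instead of the delete-and-reinsert `Change` step it leaves outdated entries in place and skips them when they are popped, which is a different (lazy) technique. The function name `dial_shortest_paths` and its parameters are assumptions, and integer weights in [0, max_weight] are assumed.

```python
def dial_shortest_paths(adj, start, max_weight):
    """Circular-bucket Dijkstra for integer weights in [0, max_weight] (a sketch)."""
    n_buckets = max_weight + 1
    buckets = [[] for _ in range(n_buckets)]
    dist = {start: 0}
    buckets[0].append(start)
    pending = 1          # number of entries currently stored in the buckets
    done = set()
    d = 0                # distance value the pointer currently stands for
    while pending:
        bucket = buckets[d % n_buckets]
        if not bucket:
            d += 1       # advance the pointer to the next slot
            continue
        u = bucket.pop()
        pending -= 1
        if u in done or dist[u] != d:
            continue     # outdated entry: the node was queued again with a smaller key
        done.add(u)
        for v, w in adj.get(u, []):
            if d + w < dist.get(v, float('inf')):
                dist[v] = d + w
                buckets[(d + w) % n_buckets].append(v)
                pending += 1
    return dist


# The demo graph from section 2; its largest edge weight is 4.
demo = {
    's': [('a', 2), ('e', 4)],
    'a': [('e', 1), ('b', 4), ('d', 2)],
    'b': [('c', 2)],
    'd': [('b', 3), ('c', 2)],
    'e': [('d', 3)],
}
bucket_dist = dial_shortest_paths(demo, 's', max_weight=4)
print(bucket_dist)
```

The lazy variant trades a little extra bucket space for not having to search a bucket when a distance improves.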

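4 Other Function Definitions

Function: purpose
def Show_path(start: str, dist: dict, path: dict): print the paths described by the given data
def Dij2(g: nx.DiGraph, start: str): the unoptimized Dijkstra algorithm
def Dij3(g: nx.DiGraph, start: str): Dijkstra's algorithm using the heap
def Dij4(g: nx.DiGraph, start: str): Dijkstra's algorithm using the buckets
def Test(test_times: int, max_nodes: int, max_weights): running-time test; the parameters specify the number of test cases, the maximum number of nodes, and the maximum weight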
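The path printing relies on a predecessor map: each node stores the node it was reached from, and the start node stores None. A minimal sketch of walking that map back to the start is shown below; `build_path` is a hypothetical helper, and the map is the one implied by the demo output in section 2.

```python
def build_path(end, path):
    """Follow predecessor links back to the start and return the node list."""
    nodes = [end]
    while path[nodes[-1]] is not None:
        nodes.append(path[nodes[-1]])
    return nodes[::-1]


# Predecessor map matching the demo output: node -> node it was reached from
path = {'s': None, 'a': 's', 'e': 'a', 'b': 'a', 'd': 'a', 'c': 'd'}
print('->'.join(build_path('c', path)))   # s->a->d->c
```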

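5 Timing Tests

In this part, given the number of test cases, the maximum number of nodes, and the maximum weight, we use the random module to generate nodes, edges, and weights. The networkx built-in function is_weakly_connected checks whether each randomly generated graph is weakly connected; if it is, the graph is passed to all three algorithms, and their running times are accumulated separately. Part of a run is shown below: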
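The test itself relies on networkx for both steps. Purely to illustrate what "generate a random directed graph, then keep it only if it is weakly connected" means, here is a standard-library stand-in. The helpers `random_digraph` and `weakly_connected` are hypothetical names for this sketch and are not networkx functions.

```python
import random


def random_digraph(n, m, max_weight, rng):
    """Pick m distinct directed edges (no self-loops) among n nodes, with random weights."""
    all_pairs = [(u, v) for u in range(n) for v in range(n) if u != v]
    return {edge: rng.randint(1, max_weight) for edge in rng.sample(all_pairs, m)}


def weakly_connected(n, edges):
    """True if the graph is connected once edge directions are ignored."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]   # path halving
            x = parent[x]
        return x

    for u, v in edges:
        parent[find(u)] = find(v)           # merge the two endpoints' components
    return len({find(x) for x in range(n)}) == 1


rng = random.Random(0)
n = 10
m = rng.randint(n, n * (n - 1) // 4)        # same edge-count range as the test
g = random_digraph(n, m, max_weight=100, rng=rng)
print(len(g), weakly_connected(n, g))
```

Graphs that fail the check are simply skipped, which is why the test reports the number of valid test cases separately from the number requested.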

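[Figure: screenshot of a test run]

Part of the measured running times is listed in the table below:

Time per 1000 runs    50      100     200     500      1000     2000      5000
Dij2                  0.135   0.406   1.358   8.853    33.434   134.999   1069.492
Dij3                  0.36    0.857   2.263   9.731    28.055   91.134    656.423
Dij4                  0.242   0.643   1.902   10.891   38.752   122.165   847.937

The corresponding plot is shown below:

[Figure: running-time plot for Dij2, Dij3, and Dij4]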

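Because the bucket index is still not designed cleanly enough, the bucket algorithm remains slower than the heap algorithm. Another possible cause is that, to keep the tests fast, we restricted the number of edges to the range $[n, \frac{n(n-1)}{4}]$. By derivation, the theoretical time complexities of the heap and bucket algorithms are $O(m\log n)$ and $O(m+nC)$ respectively, where $n$ is the number of nodes, $m$ the number of edges, and $C$ the maximum edge weight, so the bucket algorithm gains a larger advantage when the number of edges is large.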
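To make the comparison of the two bounds concrete, the sketch below plugs numbers into m log n and m + nC at both ends of the tested edge range. The values of `n` and `C` are illustrative assumptions, not measurements from this post, and constant factors are ignored, so these are rough operation-count estimates only.

```python
import math


def heap_cost(n, m):
    """Operation-count estimate for the heap version: m * log2(n)."""
    return m * math.log2(n)


def bucket_cost(n, m, C):
    """Operation-count estimate for the bucket version: m + n * C."""
    return m + n * C


n, C = 1000, 100   # illustrative values, not measurements from the post
for m in (n, n * (n - 1) // 4):
    print(m, round(heap_cost(n, m)), bucket_cost(n, m, C))
```

With these numbers the heap estimate is smaller at m = n, while the bucket estimate is smaller at m = n(n-1)/4, which matches the statement that buckets pay off on graphs with many edges.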

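6 Difficulties Encountered

The first question was which storage structure to choose. The earlier proj1 logically used an adjacency list, implemented in practice with dictionaries. Adding weight information to the graph would have required a redesign, and it would also have been inconvenient for later projects.

So we chose networkx, an existing library for network graph analysis. It removes the need to think about graph construction, makes it easier to focus on the algorithm itself, and simplifies graph generation in the later tests.

The second question was the design of the heap and bucket structures. For the heap, we improved on the heap from pro2: entries are stored in a single array, each entry is a dict, the key field determines the entry's position, and info stores the node and adjacency information. The buckets use the same design.

During testing we found that the unoptimized Dij2 actually ran faster than the heap and bucket versions. After repeated investigation we changed how the heap is indexed and how Dij4 updates neighboring nodes. The running times are still not clearly better than Dij2's, but we got a feel for the characteristics of heaps and buckets; what remains is polishing the details.

7 Source Code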

import networkx as nx
import time
import random


class Heap():
    def __init__(self):
        self.heap = list()
        self.__index = dict()

    def __str__(self):
        ans = self.heap.__str__() + '\n' + self.__index.__str__()
        return ans

    def Up(self, i: int):
        while True:
            if i < 2:
                return
            if self.heap[i // 2 - 1]['key'] > self.heap[i - 1]['key']:
                self.heap[i // 2 - 1], self.heap[i - 1] = self.heap[i - 1], self.heap[i // 2 - 1]
                ori = self.heap[i // 2 - 1]['info'][0]
                nex = self.heap[i - 1]['info'][0]
                self.__index[ori], self.__index[nex] = self.__index[nex], self.__index[ori]
                i = i // 2
            else:
                return

    def Down(self, i: int):
        n = len(self.heap)
        while True:
            if i > n // 2:
                return

            temp = 2 * i + 1 if 2 * i + 1 - 1 <= n - 1 and self.heap[2 * i - 1]['key'] > self.heap[(2 * i + 1) - 1][
                'key'] else 2 * i
            if self.heap[temp - 1]['key'] < self.heap[i - 1]['key']:
                self.heap[temp - 1], self.heap[i - 1] = self.heap[i - 1], self.heap[temp - 1]
                ori = self.heap[temp - 1]['info'][0]
                nex = self.heap[i - 1]['info'][0]
                self.__index[ori], self.__index[nex] = self.__index[nex], self.__index[ori]
                i = temp
            else:
                return

    def Insert(self, weight: int, node: str, neighbor: str):
        self.heap.append({'key': weight, 'info': [node, neighbor]})
        self.__index.update({node: len(self.heap) - 1})
        # Sift the new entry up
        self.Up(len(self.heap))

    def GetMin(self):
        ans = self.heap[0]
        ans_node = ans['info'][0]
        del self.__index[ans_node]
        self.heap[0] = self.heap[-1]
        self.heap.pop()
        if len(self.heap) != 0:
            new_fir_node = self.heap[0]['info'][0]
            self.__index[new_fir_node] = 0
        self.Down(1)
        return ans

    def GetIndex(self, find: str):
        # Return the node's position in the heap, or None if it is absent
        return self.__index.get(find)

    def Change(self, weight: int, node: str, neighbor: str):
        i = self.GetIndex(node)
        self.heap[i]['key'] = weight
        self.heap[i]['info'][1] = neighbor
        if (i + 1) // 2 - 1 >= 0 and self.heap[i]['key'] < self.heap[(i + 1) // 2 - 1]['key']:
            # The new value is smaller than the parent's value
            # Sift up
            self.Up(i + 1)
        else:
            self.Down(i + 1)


class CBucket():
    def __init__(self, num: int):
        self.bucket = [[] for _ in range(num)]  # main bucket storage
        self.__index = dict()  # bucket each node currently sits in
        self.__point = 0  # pointer position
        self.__C = num - 1

    def __str__(self):
        return f'{self.bucket.__str__()}\n{self.__index.__str__()}\npoint:{self.__point}'

    def Insert(self, data: dict):
        # data['key'] is the weight that selects the bucket; data['info'] holds the node name and its adjacent node
        index = data['key'] % (self.__C + 1)
        self.bucket[index].append(data)
        self.__index.update({data['info'][0]: index})

    def FindMin(self) -> dict:
        # Search for the first non-empty bucket, starting at the pointer
        begin = self.__point
        while True:
            if self.bucket[begin]:
                # Found the first non-empty bucket
                ans = self.bucket[begin].pop()
                # Remove the index entry along with the popped element
                node = ans['info'][0]
                del self.__index[node]
                return ans
            # Otherwise advance the pointer
            begin = (begin + 1) % (self.__C + 1)
            self.__point = begin

    def Change(self, node: str, new_weight, new_neighbor):
        index = self.__index[node]
        # Find the node's bucket and delete its entry
        # (the bucket may hold several nodes)
        for i in range(len(self.bucket[index])):
            if self.bucket[index][i]['info'][0] == node:
                del self.bucket[index][i]
                break
        # Remove the node from the index
        self.__index.pop(node)
        # Insert into the new bucket
        self.Insert({'key': new_weight, 'info': [node, new_neighbor]})

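    def IsNull(self) -> bool:
        return len(self.__index) == 0

    def GetWeight(self, node: str):
        # The queried node may be absent; return None in that case
        try:
            index = self.__index[node]
            # While Dijkstra runs, entries in the same bucket share the same key,
            # so the first entry's key is also this node's key
            return self.bucket[index][0]['key']
        except KeyError:
            return None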


def Show_path(start: str, dist: dict, path: dict):
    nodes = set(dist.keys()) - {start}
    for end in nodes:
        print(f'{start}->{end}:weight:{dist[end]}', end='\t')

        cur_path = path[end]
        ans = f'{end}'
        while cur_path is not None:
            ans = f'{cur_path}->' + ans
            cur_path = path[cur_path]
        print(ans)


def Dij2(g: nx.DiGraph, start: str):
    nodes = set(g.nodes)
    dist = dict()
    path = dict()
    for item in g.nodes:
        dist.update({item: float('inf')})
        path.update({item: None})
    joined = set()
    dist[start] = 0
    while joined != nodes:
        unjoined = nodes - joined
        join_weight = float('inf')
        join_node = None
        # Find the unjoined node with the smallest distance
        for item in unjoined:
            if dist[item] <= join_weight:
                join_weight = dist[item]
                join_node = item
        joined.add(join_node)
        # Relax the edges leaving the newly joined node
        neighbors = g.neighbors(join_node)
        for item in neighbors:
            weight = g.get_edge_data(join_node, item)['weight']
            if dist[join_node] + weight < dist[item]:
                dist[item] = dist[join_node] + weight
                path[item] = join_node
    return dist, path


def Dij3(g: nx.DiGraph, start: str):
    nodes = set(g.nodes)
    dist = dict()
    path = dict()
    heap = Heap()
    heap.Insert(0, start, None)
    for item in nodes - {start}:
        heap.Insert(float('inf'), item, None)

    joined = set()
    while joined != nodes:
        pop = heap.GetMin()
        join_node = pop['info'][0]
        join_path = pop['info'][1]
        join_weight = pop['key']
        # Record the popped node in the result
        joined.add(join_node)
        dist.update({join_node: join_weight})
        path.update({join_node: join_path})
        # Relax the edges leaving the newly joined node
        neighbors = g.neighbors(join_node)
        for item in set(neighbors) - joined:
            weight = g.get_edge_data(join_node, item)['weight']
            index = heap.GetIndex(item)
            dist_in_heap = heap.heap[index]['key']
            if dist[join_node] + weight < dist_in_heap:
                heap.Change(dist[join_node] + weight, item, join_node)
    return dist, path


def Dij4(g: nx.DiGraph, start: str):
    dist = dict()
    path = dict()
    C = 0
    # Find the maximum edge weight
    for value1 in g.adj.values():
        for value2 in value1.values():
            temp = value2['weight']
            if C < temp:
                C = temp

    cb = CBucket(C + 1)
    cb.Insert({'key': 0, 'info': [start, None]})
    joined = set()
    # Loop while the buckets are not empty
    while not cb.IsNull():

        pop = cb.FindMin()
        join_node = pop['info'][0]
        join_path = pop['info'][1]
        join_weight = pop['key']
        # Record the popped node in the result
        dist.update({join_node: join_weight})
        path.update({join_node: join_path})
        joined.add(join_node)

        neighbors = g.neighbors(join_node)
        for item in set(neighbors) - joined:
            old_weight = cb.GetWeight(item)
            if old_weight is None:
                old_weight = float('inf')
            new_weight = g.get_edge_data(join_node, item)['weight'] + join_weight

            if new_weight < old_weight:
                if old_weight == float('inf'):
                    # The node is not in the buckets yet, so insert it
                    cb.Insert({'key': new_weight, 'info': [item, join_node]})
                else:
                    cb.Change(item, new_weight, join_node)
    return dist, path


def Test(test_times: int, max_nodes: int, max_weights):
    print(f'Test cases: {test_times}\tMax nodes: {max_nodes}\tMax weight: {max_weights}\t')
    real_test = 0
    real_nodes = 0
    real_vexs = 0
    time_dij2 = 0
    time_dij3 = 0
    time_dij4 = 0
    finish = 1
    for _ in range(test_times):
        print(f'\r' + '=' * int(50 * finish / test_times) + f'=>{finish}/{test_times}', end='')
        finish += 1
        # Random number of nodes
        now_nodes = random.randint(10, max_nodes)
        # Random number of edges, between n and n(n-1)/4
        now_vexs = random.randint(now_nodes, int(now_nodes * (now_nodes - 1) / 4))
        # Let networkx generate a random directed graph
        g = nx.generators.random_graphs.gnm_random_graph(now_nodes, now_vexs, directed=True)
        # Check weak connectivity
        if nx.is_weakly_connected(g):
            for u, v in g.edges:
                g[u][v]['weight'] = random.randint(1, max_weights)

            real_test += 1
            real_nodes += now_nodes
            real_vexs += now_vexs
            t1 = time.time()
            Dij2(g, list(g.nodes)[0])
            t2 = time.time()
            Dij3(g, list(g.nodes)[0])
            t3 = time.time()
            Dij4(g, list(g.nodes)[0])
            t4 = time.time()
            time_dij2 += t2 - t1
            time_dij3 += t3 - t2
            time_dij4 += t4 - t3

    print(
        f'\nValid tests: {real_test}\nTotal nodes: {real_nodes}\nTotal edges: {real_vexs}\nDij2 time: {time_dij2}\tDij3 time: {time_dij3}\nDij4 time: {time_dij4}')


if __name__ == '__main__':
    # g = nx.DiGraph()
    # g.add_weighted_edges_from(
    #     [('s', 'a', 2), ('s', 'e', 4), ('a', 'e', 1), ('a', 'b', 4), ('a', 'd', 2), ('b', 'c', 2), ('d', 'b', 3),
    #      ('d', 'c', 2), ('e', 'd', 3)])
    # g.add_weighted_edges_from([('s', 'b', 1), ('b', 'c', 4), ('c', 's', 2)])
    #
    # dist2, path2 = Dij2(g, 's')
    # Show_path('s', dist2, path2)
    #
    # dist3, path3 = Dij3(g, 's')
    # Show_path('s', dist3, path3)
    #
    # dist4, path4 = Dij4(g, 's')
    # Show_path('s', dist4, path4)
    #
    # Test(10000, 50, 100)
    # Console menu
    while True:
        print('1\tRun the built-in example')
        print('2\tRun on a custom graph')
        print('3\tRun the timing test')
        print('0\tExit')
        choose = int(input('Enter an option: '))
        if choose == 1:
            g = nx.DiGraph()
            g.add_weighted_edges_from(
                [('s', 'a', 2), ('s', 'e', 4), ('a', 'e', 1), ('a', 'b', 4), ('a', 'd', 2), ('b', 'c', 2),
                 ('d', 'b', 3),
                 ('d', 'c', 2), ('e', 'd', 3)])
            print('-' * 10 + 'Dij2' + '-' * 10)
            dist2, path2 = Dij2(g, 's')
            Show_path('s', dist2, path2)

            print('-' * 10 + 'Dij3' + '-' * 10)
            dist3, path3 = Dij3(g, 's')
            Show_path('s', dist3, path3)

            print('-' * 10 + 'Dij4' + '-' * 10)
            dist4, path4 = Dij4(g, 's')
            Show_path('s', dist4, path4)

        elif choose == 2:
            print('Enter the edges in the following format')
            print('A B 9;A C 3')
            print('Press Enter to finish')
            g = nx.DiGraph()
            try:
                temp = input()
                temp = temp.split(';')
                for item in temp:
                    a = item.split(' ')
                    g.add_edge(a[0], a[1], weight=int(a[2]))
            except (IndexError, ValueError):
                # Stop at the first malformed edge and keep the edges read so far
                pass
            start = input('Enter the start node: ')
            print('Input finished')
            if nx.is_weakly_connected(g):
                print('-' * 10 + 'Dij2' + '-' * 10)
                dist2, path2 = Dij2(g, start)
                Show_path(start, dist2, path2)

                print('-' * 10 + 'Dij3' + '-' * 10)
                dist3, path3 = Dij3(g, start)
                Show_path(start, dist3, path3)

                print('-' * 10 + 'Dij4' + '-' * 10)
                dist4, path4 = Dij4(g, start)
                Show_path(start, dist4, path4)
            else:
                print('Please enter a weakly connected graph...')
        elif choose == 3:
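            print('Note: the algorithms themselves are not slow, but networkx is, and adding edges takes a long time\nKeep the number of tests and nodes small\nThe test results are in the report\nFeel free to try it yourself')
            test_times = int(input('Number of test cases: '))
            max_nodes = int(input('Maximum number of nodes (n >= 10): '))
            max_weights = int(input('Maximum weight: '))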
            Test(test_times, max_nodes, max_weights)
        elif choose == 0:
            break
        else:
            print('Invalid option...')