图网络(GNN) DGL库的使用

从零开始的DGL库学习

dgl.DGLGraph

最基础的API,创建一个图
基础图类。
图存储节点、边以及它们的特性。
DGL图总是有方向的。无向图可以用两条双向边表示。
节点由从零开始的连续整数标识。
边可以由两个端点(u, v)或添加边时分配的整数id指定。
边缘ID是按照加法的顺序自动分配的,即第一个被添加的边缘的ID是0,第二个是1,以此类推。
节点和边缘特征以字典的形式存储,从特征名到特征数据(以张量的形式)。

import dgl
import torch as th
G = dgl.DGLGraph()
G.add_nodes(3)
G.ndata['装机容量'] = th.zeros((3, 5))  
G.ndata['辐照度'] = th.zeros((3, 5))   # 每个点都可以有多个特征
G.ndata
'''
输出:
{'装机容量': tensor([[0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.]]), 
 '辐照度':   tensor([[0., 0., 0., 0., 0.],
        		    [0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.]])}
'''
G.nodes[[1, 2]].data['装机容量'] = th.ones((2, 5))
G.nodes[2].data['装机容量'] = th.ones((1, 5))*2
G.ndata  # 查看点的数据
'''
输出
{'装机容量': tensor([[0., 0., 0., 0., 0.],
        		    [1., 1., 1., 1., 1.],
       			    [2., 2., 2., 2., 2.]]),
'辐照度':    tensor([[0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.]])}
'''
# G.add_edges([0, 1], 2)  # 0->2, 1->2  # ,间隔的是边的起点和终点
G.add_edges( 0,[1, 2])    # 0->1, 0->2 
G.edata['y'] = th.zeros((2, 4))         # 边同样可以有多个特征
G.edata
'''
{'y': tensor([[0., 0., 0., 0.],
              [0., 0., 0., 0.]])}
'''
# G.edges[0].data['y'] += 2.  点和边也是有编号的,可以按照list取,对某个特征进行运算

更新节点特性的一个常见操作是消息传递,其中源节点通过边缘向目的地发送消息。
使用:class: ’ DGLGraph ',我们可以使用:func: ’ send '和:func: ’ recv '来实现这一点。
在下面的示例中,边的目的节点(终点)的装机容量值+7 作为消息添加到它们的节点特性中,并将消息发送到目的地。

# 定义消息发送函数 返回字典结构 {消息的名称:消息本体(tensor的值形式)}
def send_source(edges): 
    print(edges.dst['装机容量'])
    return {'m': edges.dst['装机容量'] + 7}
# 将定义的函数设置为默认消息函数  UDF(user define func)
G.register_message_func(send_source)
# 所有的边都发送消息, 消息已经上了路,发送到了各自的终点,终点接收不接收(v=?),怎么接收(UDF_RECV)是recv的事情
G.send(G.edges())   #    G.edges()   tensor([0, 0]), tensor([1, 2])

就像您需要到邮箱中检索邮件一样
节点还需要接收消息并可能更新其功能。

def simple_reduce(nodes): 
    print(nodes.mailbox['m'].sum(1)) 
    ''' tensor([[8., 8., 8., 8., 8.],
                [9., 9., 9., 9., 9.]])''' 
     # 接受点v设置成了1 2 ,则消息的值就是点1 2的装机容量+7 , sum(1)去除维度问题
     # 返回值字典结构{点特征的名称:对消息的处理(tensor的值形式)}
    return {'辐照度': nodes.data['装机容量'] + nodes.mailbox['m'].sum(1)}
    #  将辐照度设置成为2倍的装机容量 + 7
G.register_reduce_func(simple_reduce)
G.recv(v=[1,2])   # 终点为1 2的点接收传入的信息

print(G.edges(),'\n', G.edata,'\n', G.ndata)
'''(tensor([0, 0]), tensor([1, 2])) 
   {'y': tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])} 
   {'装机容量': tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.]]),
    '辐照度': tensor([[ 0.,  0.,  0.,  0.,  0.],
        [ 9.,  9.,  9.,  9.,  9.],
        [11., 11., 11., 11., 11.]])}
'''

0不是任何边的终点,所以0不会有传入消息

  1. G.recv(v=[0]) 若改成0 ,则会tensor全为0
    如果所有v都没有传入消息,则将降级为apply_nodes()。
  2. G.recv(v=[0,1]) 提示使用默认的初始化器
    如果一些v没有传入消息,它们的新特性值将由列初始化器计算(参见set_n_initializer())。将推断特征形状和d类型。
    docs
  • 11
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值