RGCN layer implementation in pytorch

本文详细介绍了如何在PyTorch中实现RGCN(关系图卷积网络)层。首先,将多关系图拆分为多个同质图,每个图对应一种边类型。接着,为每种边类型初始化一个GCN层,学习特定关系的节点特征。在前向传播过程中,通过线性变换处理节点特征,并结合邻接矩阵进行消息传递和聚合。最后,通过加权求和与节点自身的特征相结合,得到最终的输出特征。RGCN在处理图结构数据时,能够捕捉节点间复杂的关系信息。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

RGCN layer implementation in torch

  1. 将含有 N 种edge type的多关系图,拆成 N 个 同质图,每个同质图只包含一种edge type及该edge type所连接到的节点
  2. 对每一种edgetype初始化一个GCN层,每个GCN学习一种关系的同质图
  3. 对于每个节点,聚合各个关系下学习到的节点特征(aggregation = sum+mean)
class GraphConvolutionLayer(Module): # RGCN
    def __init__(self, in_features, out_features, activation, edge_type_num, dropout_rate=0.):
        super(GraphConvolutionLayer, self).__init__()
        self.edge_type_num = edge_type_num
        self.out_features = out_features
        self.edgeType_linears = nn.ModuleList()        
        #- self.edgeType_linears:
        #- GCN卷积层中的W参数,其实就是对node-feature的线性变换
        #- 在RGCN中,对每一种relation,分别定义一个GCN,实质就是定义该GCN的W参数,
        #-- 有R种relation,就定义R个W,用于在不同关系的GCN中,学习该关系下node-feature
        
        for _ in range(self.edge_type_num):
            self.edgeType_linears.append(nn.Linear(in_features, out_features))
        self.linear_2 = nn.Linear(in_features, out_features)
        self.activation = activation
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, n_tensor, adj_tensor, h_tensor=None):
        # adj_tensor - torch.Size([batchsize, num_edgeTypes, num_nodes, num_nodes])
        # n_tensor - torch.Size([batchsize, num_nodes, node_feat_dim])
        if h_tensor is not None:
            node_annotations = torch.cat((n_tensor, h_tensor), -1)
        else:
            node_annotations = n_tensor
		# h_tensor 用于扩展RGCN的层数,搭建两层或三层RGCN时,h_tensor就是上一层RGCN输出的隐层nodes-embeddings

        node_feat_all_edge_type = []
        for edge_type in range(self.edge_type_num):
            node_feat_single_edge_type = self.edgeType_linears[edge_type](node_annotations)
            node_feat_all_edge_type.append(node_feat_single_edge_type)

        output = torch.stack(node_feat_all_edge_type, dim=1)
		
		# 得到邻居消息传递:message passing and aggregate with neighbors according to adj.
		# 用etypes个GCN(对应的W参数保存于self.edgeType_linears)处理不同关系下的node-feature
        output = torch.matmul(adj_tensor, output) # adj_tensor(32,4,9,9) * output(32,4,9,128)
        # using SUM to aggregate the node-embeddings under different edge types:
        out_sum = torch.sum(input=output, dim=1) # out_sum(32,9,128) 聚合的邻居消息

        # aggregate with the node's self feature:		
        node_self_annotation = self.linear_2(node_annotations) # node_self_annotation(32,9,128) 聚合的邻居消息
        # 将传来的消息与自身特征聚合
        output = out_sum + node_self_annotation # output(32,9,128)
        output = self.activation(output) if self.activation is not None else output
        output = self.dropout(output)
        return output
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值