pytorch自定义用于稀疏矩阵乘法的反向传播函数

class SpecialSpmmFunction(torch.autograd.Function):
    """Special function for only sparse region backpropataion layer."""
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)

一、 整体代码如上,分为两个部分:1、一个继承了torch.autograd.function的类,类中有两个静态函数,forward和backward ; 2 、一个实现了上述类的spmm类,用于调用并且返回梯度结果,两个部分结合使用的效果相当于一个函数。

这个函数的作用是实现稀疏矩阵乘法的前向传播和反向传播。它接收四个输入参数:indices(稀疏矩阵的非零元素的位置信息),values(稀疏矩阵的非零元素的值),shape(矩阵的形状)和b(矩阵乘法的运算数)。

二、在前向传播中,它首先使用indices和values创建一个稀疏COO张量a,然后将a和b进行矩阵乘法运算,返回结果。这是前向传播过程,也就是计算稀疏矩阵和矩阵b相乘的过程。

三、接下来进行反向传播backward,它接收一个梯度张量grad_output作为输入。根据链式法则,它计算并返回了对输入的梯度。具体而言,它计算出grad_a和grad_b,其中grad_a是对values的梯度,grad_b是对b的梯度。但是和b不一样,由于a是一个稀疏矩阵,grad_a的计算需要特殊处理。它首先计算grad_a_dense,即将grad_output与b的转置矩阵相乘得到的密集梯度矩阵。然后,它通过索引操作将grad_a_dense中的值提取出来,索引edge_index是根据a的indices计算得到的。最后,grad_b的计算通过将a的转置矩阵与grad_output相乘得到。

四、关于edge_index的提取:

a._indices() 返回稀疏张量 a 的索引,其中第一行表示非零元素在行的位置,第二行表示非零元素在列的位置。a._indices()[0, :] 返回索引的第一行,即非零元素在行的位置。

接下来,这行代码计算了 edge_idx,它由两部分组成:

  • 第一部分是 a._indices()[0, :] * ctx.N,它将非零元素在行的位置乘以稀疏矩阵的总行数 ctx.N。这可以将行的位置转换为全局索引。
  • 第二部分是 a._indices()[1, :],它表示非零元素在列的位置。

将两部分相加,得到了 edge_idx,它是一个一维张量,包含了在 grad_a_dense 中需要提取梯度的位置索引。这样,通过 grad_a_dense.view(-1)[edge_idx] 就可以提取对应位置的梯度值。

总结: 这个函数实现了稀疏区域的矩阵乘法和梯度计算。它通过使用稀疏张量和特殊的索引操作,在需要的时候只计算稀疏矩阵的非零元素,从而提高了计算效率。

代码地址 : https://github.com/Diego999/pyGAT

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值