class SpecialSpmmFunction(torch.autograd.Function):
"""Special function for only sparse region backpropataion layer."""
@staticmethod
def forward(ctx, indices, values, shape, b):
assert indices.requires_grad == False
a = torch.sparse_coo_tensor(indices, values, shape)
ctx.save_for_backward(a, b)
ctx.N = shape[0]
return torch.matmul(a, b)
@staticmethod
def backward(ctx, grad_output):
a, b = ctx.saved_tensors
grad_values = grad_b = None
if ctx.needs_input_grad[1]:
grad_a_dense = grad_output.matmul(b.t())
edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
grad_values = grad_a_dense.view(-1)[edge_idx]
if ctx.needs_input_grad[3]:
grad_b = a.t().matmul(grad_output)
return None, grad_values, None, grad_b
class SpecialSpmm(nn.Module):
def forward(self, indices, values, shape, b):
return SpecialSpmmFunction.apply(indices, values, shape, b)
一、 整体代码如上,分为两个部分:1、一个继承了torch.autograd.function的类,类中有两个静态函数,forward和backward ; 2 、一个实现了上述类的spmm类,用于调用并且返回梯度结果,两个部分结合使用的效果相当于一个函数。
这个函数的作用是实现稀疏矩阵乘法的前向传播和反向传播。它接收四个输入参数:indices(稀疏矩阵的非零元素的位置信息),values(稀疏矩阵的非零元素的值),shape(矩阵的形状)和b(矩阵乘法的运算数)。
二、在前向传播中,它首先使用indices和values创建一个稀疏COO张量a,然后将a和b进行矩阵乘法运算,返回结果。这是前向传播过程,也就是计算稀疏矩阵和矩阵b相乘的过程。
三、接下来进行反向传播backward,它接收一个梯度张量grad_output作为输入。根据链式法则,它计算并返回了对输入的梯度。具体而言,它计算出grad_a和grad_b,其中grad_a是对values的梯度,grad_b是对b的梯度。但是和b不一样,由于a是一个稀疏矩阵,grad_a的计算需要特殊处理。它首先计算grad_a_dense,即将grad_output与b的转置矩阵相乘得到的密集梯度矩阵。然后,它通过索引操作将grad_a_dense中的值提取出来,索引edge_index是根据a的indices计算得到的。最后,grad_b的计算通过将a的转置矩阵与grad_output相乘得到。
四、关于edge_index的提取:
a._indices()
返回稀疏张量 a
的索引,其中第一行表示非零元素在行的位置,第二行表示非零元素在列的位置。a._indices()[0, :]
返回索引的第一行,即非零元素在行的位置。
接下来,这行代码计算了 edge_idx
,它由两部分组成:
- 第一部分是
a._indices()[0, :] * ctx.N
,它将非零元素在行的位置乘以稀疏矩阵的总行数ctx.N
。这可以将行的位置转换为全局索引。 - 第二部分是
a._indices()[1, :]
,它表示非零元素在列的位置。
将两部分相加,得到了 edge_idx
,它是一个一维张量,包含了在 grad_a_dense
中需要提取梯度的位置索引。这样,通过 grad_a_dense.view(-1)[edge_idx]
就可以提取对应位置的梯度值。
总结: 这个函数实现了稀疏区域的矩阵乘法和梯度计算。它通过使用稀疏张量和特殊的索引操作,在需要的时候只计算稀疏矩阵的非零元素,从而提高了计算效率。