代码摘自LSS: class QuickCumsum(torch.autograd.Function): @staticmethod def forward(ctx, x, geom_feats, ranks): x = x.cumsum(0) kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool) kept[:-1] = (ranks[1:] != ranks[:-1]) x, geom_feats = x[kept], geom_feats[kept] x = torch.cat((x[:1], x[1:] - x[:-1])) # save kept for backward ctx.save_for_backward(kept) # no gradient for geom_feats ctx.mark_non_differentiable(geom_feats) return x, geom_feats @staticmethod def backward(ctx, gradx, gradgeom): kept, = ctx.saved_tensors back = torch.cumsum(kept, 0) back[kept] -= 1 val = gradx[back] return val, None, None
给出这样一个类,x, geom_feats = QuickCumsum.apply(x, geom_feats, ranks),
解释.apply用法和apply括号中的参数:
在PyTorch中,torch.autograd.Function
允许用户自定义自动微分操作——即定义自己的前向和反向传播函数。QuickCumsum
类是torch.autograd.Function
的一个子类,用于实现一个特定的操作,这个操作在前向传播中计算输入张量的累积和(cumulative sum),并根据ranks
张量过滤某些元素,然后在反向传播中计算这个操作的梯度。
.apply
方法是torch.autograd.Function
的一个内置方法,它允许你调用该函数。当你使用.apply
时,实际上是在调用类的forward
方法,并且自动处理反向传播(通过backward
方法)。这意味着,当你调用QuickCumsum.apply(x, geom_feats, ranks)
时,实际上是执行以下步骤:
x, geom_feats, ranks
是传入.apply
方法的参数。x
是要进行累积和操作的张量。geom_feats
是与x
相关联的几何特征张量,根据实现看,在前向传播中会被过滤和修改,但被标记为不需要梯度(即在反向传播中不计算其梯度)。ranks
是一个张量,其值用于决定如何过滤x
和geom_feats
中的元素。特别是,它决定了哪些元素在计算累积和之后应该被保留。
在forward
方法中的具体步骤如下:
- 首先,计算
x
的累积和。 - 接着,创建一个布尔张量
kept
,用于决定哪些元素在计算累积和之后应该被保留。这是通过比较ranks
张量相邻元素的值来决定的。 - 然后,根据
kept
张量过滤x
和geom_feats
。 - 最后,调整过滤后的
x
,使其表示原始序列中的变化量,而不是累积量。
在backward
方法中,实现了自定义的梯度传递逻辑,这是基于前向传播中保留下来的kept
张量。这确保了梯度正确地传回通过过滤操作保留下来的元素,而不是原始输入张量中的所有元素。
总结,.apply
方法使得QuickCumsum
能够作为一个自定义的自动微分操作被调用,其中x, geom_feats, ranks
是该操作的输入参数,而这个类的实现定义了这些输入是如何被处理以及如何计算它们的梯度的。