torch.autograd.Function

代码摘自LSS:

class QuickCumsum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, geom_feats, ranks):
        x = x.cumsum(0)
        kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
        kept[:-1] = (ranks[1:] != ranks[:-1])

        x, geom_feats = x[kept], geom_feats[kept]
        x = torch.cat((x[:1], x[1:] - x[:-1]))

        # save kept for backward
        ctx.save_for_backward(kept)

        # no gradient for geom_feats
        ctx.mark_non_differentiable(geom_feats)

        return x, geom_feats

    @staticmethod
    def backward(ctx, gradx, gradgeom):
        kept, = ctx.saved_tensors
        back = torch.cumsum(kept, 0)
        back[kept] -= 1

        val = gradx[back]

        return val, None, None

给出这样一个类,x, geom_feats = QuickCumsum.apply(x, geom_feats, ranks),

解释.apply用法和apply括号中的参数:

在PyTorch中,torch.autograd.Function允许用户自定义自动微分操作——即定义自己的前向和反向传播函数。QuickCumsum类是torch.autograd.Function的一个子类,用于实现一个特定的操作,这个操作在前向传播中计算输入张量的累积和(cumulative sum),并根据ranks张量过滤某些元素,然后在反向传播中计算这个操作的梯度。

.apply方法是torch.autograd.Function的一个内置方法,它允许你调用该函数。当你使用.apply时,实际上是在调用类的forward方法,并且自动处理反向传播(通过backward方法)。这意味着,当你调用QuickCumsum.apply(x, geom_feats, ranks)时,实际上是执行以下步骤:

  • x, geom_feats, ranks是传入.apply方法的参数。
  • x是要进行累积和操作的张量。
  • geom_feats是与x相关联的几何特征张量,根据实现看,在前向传播中会被过滤和修改,但被标记为不需要梯度(即在反向传播中不计算其梯度)。
  • ranks是一个张量,其值用于决定如何过滤xgeom_feats中的元素。特别是,它决定了哪些元素在计算累积和之后应该被保留。

forward方法中的具体步骤如下:

  1. 首先,计算x的累积和。
  2. 接着,创建一个布尔张量kept,用于决定哪些元素在计算累积和之后应该被保留。这是通过比较ranks张量相邻元素的值来决定的。
  3. 然后,根据kept张量过滤xgeom_feats
  4. 最后,调整过滤后的x,使其表示原始序列中的变化量,而不是累积量。

backward方法中,实现了自定义的梯度传递逻辑,这是基于前向传播中保留下来的kept张量。这确保了梯度正确地传回通过过滤操作保留下来的元素,而不是原始输入张量中的所有元素。

总结,.apply方法使得QuickCumsum能够作为一个自定义的自动微分操作被调用,其中x, geom_feats, ranks是该操作的输入参数,而这个类的实现定义了这些输入是如何被处理以及如何计算它们的梯度的。

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值