pytroch中ctx和self的区别

welcome to my blog

阅读某个pytorch模型源代码时碰见的ctx参数, 查阅了资料大概总结一下
  1. ctx是context的缩写, 翻译成"上下文; 环境"
  2. ctx专门用在静态方法中
  3. self指的是实例对象; 而ctx用在静态方法中, 调用的时候不需要实例化对象, 直接通过类名就可以调用, 所以self在静态方法中没有意义
  4. 自定义的forward()方法和backward()方法的第一个参数必须是ctx; ctx可以保存forward()中的变量,以便在backward()中继续使用, 下一条是具体的示例
  5. ctx.save_for_backward(a, b)能够保存forward()静态方法中的张量, 从而可以在backward()静态方法中调用, 具体地, 下面地代码通过a, b = ctx.saved_tensors重新得到a和b
  6. ctx.needs_input_grad是一个元组, 元素是True或者False, 表示forward()中对应的输入是否需要求导, 比如ctx.needs_input_grad[0]指的是下面forwad()代码中indices是否需要求导
class SpecialSpmmFunction(torch.autograd.Function):
    """
    Special function for only sparse region backpropataion layer.
    """
    # 自定义前向传播过程
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)
    # 自定义反向传播过程
    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]

        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b
ctx还能调用很多方法, pytorch1.3.1源码中竟然说"no doc", 没有相关的文档…
class _FunctionBase(object):
    # no doc
    @classmethod
    def apply(cls, *args, **kwargs): # real signature unknown
        pass

    def register_hook(self, *args, **kwargs): # real signature unknown
        pass

    def _do_backward(self, *args, **kwargs): # real signature unknown
        pass

    def _do_forward(self, *args, **kwargs): # real signature unknown
        pass

    def _register_hook_dict(self, *args, **kwargs): # real signature unknown
        pass

    def __init__(self, *args, **kwargs): # real signature unknown
        pass

    @staticmethod # known case of __new__
    def __new__(*args, **kwargs): # real signature unknown
        """ Create and return a new object.  See help(type) for accurate signature. """
        pass

    dirty_tensors = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    metadata = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    needs_input_grad = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    next_functions = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    non_differentiable = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    requires_grad = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    saved_tensors = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    saved_variables = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default

    to_save = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
  • 38
    点赞
  • 71
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值