torch中(required_grad_=True),根据mask替换掉部分值

举个例子

a = torch.randn(2, 3, 4)
a.requires_grad=True
a
tensor([[[ 1.2255,  1.5870, -0.5485, -0.4809],
         [-0.3167, -0.2933, -0.0604,  0.3498],
         [ 0.1436,  0.3083,  1.6776, -1.1144]],
        [[-1.5001,  0.7174,  0.2585,  0.2669],
         [-0.1319, -0.8247,  0.1929, -0.6142],
         [ 1.1407,  2.2324, -1.3897, -0.2413]]], requires_grad=True)

假设mask为:

mask
tensor([[False,  True,  True],
        [False,  True,  True]])

(1)如果假设b为:

b = torch.arange(16).view(2, 2, 4).float()
b.requires_grad = True
b
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],
        [[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]], requires_grad=True)

直接赋值会导致出错:

a[mask] = b
Traceback (most recent call last):
  File "<input>", line 1, in <module>
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

这是因为叶子节点不能进行原地替换操作。
(2)介绍一个函数tgt.index_put_(indices, value):根据indices把tgt中的值替换为value。
另外,b的size也不能是[2, 2, 4]了,应该是[4, 4]。可以根据下面这个方法确定b的size。

a[mask].size()
torch.Size([4, 4])

则:

b = torch.arange(16).view(4, 4).float()
b.requires_grad = True
b
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]], requires_grad=True)
c = a.index_put((mask,), b)
c
tensor([[[ 1.2255,  1.5870, -0.5485, -0.4809],
         [ 0.0000,  1.0000,  2.0000,  3.0000],
         [ 4.0000,  5.0000,  6.0000,  7.0000]],
        [[-1.5001,  0.7174,  0.2585,  0.2669],
         [ 8.0000,  9.0000, 10.0000, 11.0000],
         [12.0000, 13.0000, 14.0000, 15.0000]]], grad_fn=<IndexPutBackward0>)

请注意:

  1. mask必须转为tuple,(mask,);不能使用tuple(mask),这两个的顺序不一样
(mask,)
(tensor([[False,  True,  True],
        [False,  True,  True]]),)
tuple(mask)
(tensor([False,  True,  True]), tensor([False,  True,  True]))

2.这个out-place操作,不是in-place操作,需要赋值为c才能进行前向、后向传播和梯度更新。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值