PyTorch中使用inplace operation导致无法梯度反向传播
问题
- 今天写代码的是时候想要实现这样一个功能,假设有一个输入的时间序列
input1
,input1.shape = [T, D]
,然后我想对其中的input1[start:end, :]
进行扰动,代码如下
import torch
perturbation1 = torch.nn.Parameter(torch.rand(shape))
input1 = torch.rand(shape)
# perturbed_input1 and input1 share common memory
perturbed_input1 = input1[:]
perturbed_input1[0:2,:] = input1[0:2] * (1 + perturbation1[0:2])
loss1