最近在控制台输出一下loss的时候用到了这俩函数,在这里记录一下。
总体上来说tensor.detach()
是为了解决tensor.data()
的安全性提出的。tensor.detach()
相对较为安全。因为当通过.detach()
得到的tensor
间接修改原来的tensor
后继续在计算图中使用时会报错,但是通过.data()
得到的tensor
间接修改原tensor
后继续在计算图中使用就会被忽略被修改的过程,例如:
>>> a = torch.tensor([1,2,3.], requires_grad =True)
>>> out = a.sigmoid() #原本是out这个tensor
>>> c = out.data #这里是data()
>>> c.zero_()
tensor([ 0., 0., 0.])
>>> out # 在这里out的数值被c.zero_()修改也就是提到的原来的out这个tensor被间接修改了
tensor([ 0., 0., 0.])
>>> out.sum().backward() # 修改后的out参与反向传播
>>> a.grad # 这个结果很严重的错误,因为out已经改变了
tensor([ 0., 0., 0.])
但是如果换成.detach()
>>> a = torch.tensor([1,2,3.], requires_grad =True)
>>> out = a.sigmoid() #原本是out这个tensor
>>> c = out.detach() #这里是detach()
>>> c.zero_()
tensor([ 0., 0., 0.])
>>> out # out的值被c.zero_()修改
tensor([ 0., 0., 0.])
>>> out.sum().backward() # out参与反向传播,但是已经被c.zero_()了,结果报错
RuntimeError: one of the variables needed for gradient
computation has been modified by an