The detach function returns a new tensor that is cut off from the computation graph: it shares the same underlying data as the original tensor, but has requires_grad=False, so no gradient will flow back through it. Because the detached tensor is excluded from autograd bookkeeping, it can also reduce memory use.
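A minimal sketch of these two properties (detached tensor shares storage, but drops out of autograd):

```python
import torch

a = torch.tensor([1., 2.], requires_grad=True)
b = a.detach()

print(b.requires_grad)                 # False: b is outside the graph
print(a.data_ptr() == b.data_ptr())    # True: same underlying storage
```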
A simple example:
# automatic differentiation
import torch
x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)
y = w * x + b
z = y ** 2
z.backward()
print('w.grad:', w.grad, 'x.grad:', x.grad)
The output is:
w.grad: tensor(10.) x.grad: tensor(20.)
If we now detach x after the backward pass, i.e.
# automatic differentiation
import torch
x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)
y = w * x + b
z = y ** 2
z.backward()
x = x.detach()
print('w.grad:', w.grad, 'x.grad:', x.grad)
then x.grad in the output is None. Note what happened: the name x was rebound to the detached copy, which is not part of any graph and therefore carries no .grad attribute; the gradient computed by backward() still lives on the original tensor. To actually stop gradients from flowing into x, detach it before building the graph.
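A sketch of that pattern with the same values as above: detaching x before the forward pass makes autograd treat it as a constant, so x.grad stays None after backward while w.grad is still computed.

```python
import torch

x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

# detach BEFORE building the graph: x enters the forward pass as a constant
y = w * x.detach() + b   # y = 2*1 + 3 = 5
z = y ** 2               # z = 25
z.backward()

print('w.grad:', w.grad, 'x.grad:', x.grad)
# w.grad: tensor(10.) x.grad: None
```

Here dz/dw = 2y * x = 10 as before, but no gradient reaches x because the graph only ever saw the detached copy.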