PyTorch 中的“with torch no_grad”有什么作用?

torch.no_grad()用于在深度学习计算中临时禁用张量的梯度追踪,使得在该上下文中执行的运算不会记录在计算图中,从而节省内存和提高效率。在with块内,张量的requires_grad属性设为False,退出该块后,如果张量原本需要梯度,它会重新连接到计算图。
摘要由CSDN通过智能技术生成

“with ”torch.no_grad()的使用就像一个循环,其中循环内的每个张量都将requires_grad设置为False。这意味着当前与当前计算图相连的任何具有梯度的张量现在都与当前图分离。我们不再能够计算关于这个张量的梯度。

张量从当前图中分离,直到它在循环内。一旦它离开循环,如果张量是用梯度定义的,它就会再次附加到当前图。

让我们举几个例子来更好地理解它是如何工作的。

示例 1

在这个例子中,我们用requires_grad = true创建了一个张量 x 。接下来,我们定义这个张量 x 的函数 y 并将该函数放入 with循环中。现在 x 在循环内,所以它的requires_grad被设置为Falsetorch.no_grad()

在循环中,无法计算 y 相对于 x 的梯度。所以,y.requires_grad返回False

# import torch library
import torch

# define a torch tensor
x = torch.tensor(2., requires_grad = True)
print("x:", x)

# define a function y
with torch.no_grad():
   y = x ** 2
print("y:", y)

# check gradient for Y
print("y.requires_grad:", y.requires_grad)

"""
输出结果
x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False
"""
 

示例 2

在这个例子中,我们在循环之外定义了函数z。所以,z.requires_grad返回True

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)

print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# define a function z
with torch.no_grad():
   z = w * x + b

print("z:", z)

# check if requires grad is true or not
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)

"""
输出结果
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False
"""
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值