与TensorFlow等框架使用的“静态计算图”不同。在PyTorch中,计算图是在运行时动态构建的,这意味着在每次前向传播时,都会根据当前的操作创建一个新的计算图。
下面举一个简单的例子来帮助理解:
import torch
# 创建一个张量并设置requires_grad=True来跟踪其计算历史
x = torch.ones(2, 2, requires_grad=True)
print(x)
'''tensor([[1., 1.],
[1., 1.]], requires_grad=True)'''
# 对张量进行操作
y = x + 2
print(y)
'''tensor([[3., 3.],
[3., 3.]], grad_fn=<AddBackward0>)'''
# y是操作的结果,所以它有grad_fn属性
print(y.grad_fn)
'''<AddBackward0 object at 0x0000017595769720>'''
z = y * y * 3
out = z.mean()
print(z, out)
'''tensor([[27., 27.],
[27., 27.]], grad_fn=<MulBackward0>)
tensor(27., grad_fn=<MeanBackward0>)'''
在这个例子中,我们首先创建了一个形状为(2, 2)的张量x
,并设置requires_grad=True
以便跟踪其计算历史。然后我们对x
进行了一系列的操作,最后计算了结果的均值。
在这个过程中,PyTorch会自动地为我们构建一个计算图。这个计算图记录了从输入x
到输出out
的所有操作。当对out
进行反向传播时,PyTorch会利用这个计算图来计算梯度。
例如,如果我们想计算输出out
关于输入x
的梯度,可以这样做:
out.backward()
print(x.grad)
'''tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])'''
这段代码会计算out
关于x
的梯度,并将结果存储在x.grad
中。注意,只有requires_grad=True
的张量才会有.grad
属性。在这个例子中,x
的.grad
属性现在包含了out
关于x
的梯度。
总的来说,PyTorch的计算图是动态构建的,它会在你执行操作时自动记录计算历史。这使得PyTorch的代码更加直观和易于调试,因为你不需要事先定义计算图的结构。