torch.no_grad():
作用: 上下文管理器, 在执行代码块的时候,不用记录这些更新操作的计算图和梯度信息, 减少 内存消耗和提高代码的效率。
def sgd(params, lr, batch_size):
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()
实现了随机梯度下降SGD优化算法的更新步骤,用于更新模型参数以最小化损失函数。具体来说:
1、 便利params中所有的参数,其中,params是一个包含需要更新的模型参数的列表
2、 对于每个参数 param, 使用梯度下降更新公式来更新值,具体地, 将参数的值减去一个学习率 “lr" 乘以 该参数的梯度”param.grad“ 除以批量大小 ”batch_size"
3、 在更新参数值之后, 使用 param.grad.zero_() 将参数的梯度清零, 以便在下一次计算梯度之前不受之前梯度的影响。
由于使用了‘torch.no_grad()’ 上下文管理器, 因此在执行此代码块时 , pytorch不会计算相关信息。
举个例子,便于理解:
假设你正在学习如何骑自行车,而你的目标是能够顺利地骑上一段长距离。在这个过程中,你需要调整自行车的各个部分,例如车座的高度、方向盘的位置等等,以便它们适合你的身高和骑车习惯。
在这个例子中,自行车的各个部分就对应着神经网络的参数,而调整它们的过程就对应着优化算法的更新步骤。下面是这个例子中的一些细节:
…遍历 params 中的所有参数:就像你需要调整自行车的各个部分一样,优化算法需要遍历神经网络中的所有参数,以便调整它们的值。
…使用梯度下降更新公式来更新参数的值:就像你需要不断地调整车座的高度和方向盘的位置一样,优化算法需要使用梯度下降的公式来更新神经网络中的参数。具体来说,该公式根据当前的梯度信息来计算每个参数应该如何更新。
. .将参数的梯度清零:就像你不希望之前的车座高度和方向盘位置影响到当前的调整一样,优化算法需要在更新完参数之后将其梯度清零,以便下一次计算梯度不受之前的影响。
不记录更新操作的计算图和梯度信息:就像你不需要记录每个调整步骤的细节一样,优化算法使用 torch.no_grad()
…上下文管理器来避免记录更新操作的计算图和梯度信息,以减少内存消耗和提高代码效率。