Pytorch内部中optim和loss是如何交互的?
loss.backward()获得所有parameter的gradient。
然后optimizer存了这些parameter的指针,step()根据这些parameter的gradient对parameter的值进行更新。
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
这一句代码中optimizer获取到了所有parameters的引用,每个parameter都包含梯度(gradient),optimizer可以把梯度应用上去更新parameter。
loss = loss_func(prediction, y)
loss.backward()
而梯度由这2句产生,第一句对prediction和y之间进行比对(熵或者其他loss function),产生最初的梯度,然后经由第二句反向传播到整个网络的所有链路和节点。节点与节点之间有联系,因此可以反向链式传播梯度。
optimizer.step()
这一句是apply所有的梯度以更新parameter的值,这里不需要传入梯度,因为梯度的引用已经在其构造函数中传入的parameters中包含。