使用Pytorch也一年多了,记录一下自己遇到的各种问题。
各种踩坑记录
loss.backward()报错
一种非常常见的错误,在网络前向传播时没问题,但是当loss.backward()时会报错
导致这个错误的原因非常多样
- in-place操作导致,具体的in-place操作有很多,例如squeeze_(), x[:]=y, 等等,网上相关资料很多,不细写
- loss需要是一个标量,如果是向量的话,需要loss = loss.sum();
- 有些向量不需要传递梯度时,使用x.detach()截断梯度传递;
- 一个非常有用的命令,可以加在loss.backward()外,方便定位具体哪一行导致的报错:
with torch.autograd.set_detect_anomaly(True):
loss.backward(