经典的LeNet,在PyTorch/examples/mnist 实现中有个小问题,在这里和大家分享一下。
是一个计算generalization error的问题。
计算generalization error时,原代码有一行我一直不理解,test_loss /= len(test_loader) # loss function already averages over batch size
。试了一下输出,就发现len(test_loader)
指的是进行一次data pass,会将所有的数据分割成多少份的mini-batch。然后配合原来的一行代码,在每次mini-batch计算loss的时候,test_loss += F.nll_loss(output, target).data[0]
,就可以推出原代码在求解generalization error的逻辑是这样的:
- 将一个data pass分成几个mini-batch
- 每一个mini-batch,
F.nll_loss(output, target).data[0]
的loss value并不是整个mini-batch的loss,而是average loss,有一个默认size_average=True
的参数(后面会用到) - 进行一次data pass之后,就可以将每一个mini-batch average loss求和
- 将loss之和再除以mini-batch的数量,就得到最后的data point average loss
所以到这里就能知道,这里有一个隐藏的bug:这里假设了我每一个mini-batch size是一样的,所以才能用这样求平均的方式。但实际上,最后一个mini-batch是很难正好“满上”的。
更为精确求解loss的方法是,每一个mini-batch loss不算平均,而直接求和。最后除以所有data point的个数。大概代码如下:
for each mini-batch:
...
test_loss += F.nll_loss(output, target, size_average=False).data[0]
...
...
test_loss /= len(test_loader.dataset)
...
第一次给pub repo做pull request,感觉挺不错。PyTorch还有很多坑没填,大家感兴趣的可以慢慢填。我看到很多人直接把原bug照搬到自己的repo,也不知道有没有发现这个问题。其实只要深究test_loss /= len(test_loader)
这一行就能发现。
附:因为我pull request已经被merge到code base里,感兴趣的朋友可以移步到github. 类似的问题在mnist_hogwild
也有,成功merge github.