PyTorch —— LeNet实现中的bug以及由此的小想法

经典的LeNet,在PyTorch/examples/mnist 实现中有个小问题,在这里和大家分享一下。

是一个计算generalization error的问题。

计算generalization error时,原代码有一行我一直不理解,test_loss /= len(test_loader) # loss function already averages over batch size。试了一下输出,就发现len(test_loader)指的是进行一次data pass,会将所有的数据分割成多少份的mini-batch。然后配合原来的一行代码,在每次mini-batch计算loss的时候,test_loss += F.nll_loss(output, target).data[0],就可以推出原代码在求解generalization error的逻辑是这样的:

  1. 将一个data pass分成几个mini-batch
  2. 每一个mini-batch,F.nll_loss(output, target).data[0]的loss value并不是整个mini-batch的loss,而是average loss,有一个默认size_average=True的参数(后面会用到)
  3. 进行一次data pass之后,就可以将每一个mini-batch average loss求和
  4. 将loss之和再除以mini-batch的数量,就得到最后的data point average loss

所以到这里就能知道,这里有一个隐藏的bug:这里假设了我每一个mini-batch size是一样的,所以才能用这样求平均的方式。但实际上,最后一个mini-batch是很难正好“满上”的。

更为精确求解loss的方法是,每一个mini-batch loss不算平均,而直接求和。最后除以所有data point的个数。大概代码如下:

for each mini-batch:
    ...
    test_loss += F.nll_loss(output, target, size_average=False).data[0]
    ...

...
test_loss /= len(test_loader.dataset)
...

第一次给pub repo做pull request,感觉挺不错。PyTorch还有很多坑没填,大家感兴趣的可以慢慢填。我看到很多人直接把原bug照搬到自己的repo,也不知道有没有发现这个问题。其实只要深究test_loss /= len(test_loader)这一行就能发现。

附:因为我pull request已经被merge到code base里,感兴趣的朋友可以移步到github. 类似的问题在mnist_hogwild也有,成功merge github.

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值