用Pytorch 有一段时间,因为一直找不到合适的教材,只能边看边做。看代码一直看得云里雾里,很多代码里的解释都含糊不清,一直没有办法流畅地写代码,记录一些坑,也希望可以让国内的开发者在用最好的开发工具上少走一些弯路。
国内大多数的开发者都喜欢搬一些论文,公式推导。但细节的坑不填平,和有稍微合适的解释,看完一堆论文仍然不知道怎么落地。列举一些Pytorch 坑人的地方:
1.ver 0.1是个不稳定的版本,训练速度很慢,升级到0.3以后运算效率明显提高。
2. Embedding 层的内部bug. 据说tensorflow 和其他框架也有:
https://discuss.pytorch.org/t/pytorch-0-2-nan-for-embedding/10781
如果你是做NLP 或者需要用到Embedding层,会出现算着算着Embedding里的字典矩阵返回NaN的现象。这个问题足足困扰了我近20多天导致研究一直停滞不前。国内论坛的解释大多要求增加norm初始化词向量,或者减小batch_size,还有解释要减小learning rate. 某种程度上这些方案都是因为网络的搭建有缺陷造成的。 但大多数开发者其实都不会那么没理论基础忽略这些重要的细节。
解决方案: 很奇葩地升级的pytorch到最新的版本(ver 0.4)即可。