1.首先检查的原因是训练集中出现脏数据,脏数据的出现导致我的logits计算出了0,0传给 log(x|x=0) 导致为∞, 即nan。之所以会这样,是因为我的实验是实际业务上的真实数据。所以需要一条一条检测数据是否为脏数据。
2.其次才是使用如下解决方案
(1)、数据归一化(减均值,除方差,或者加入normalization,例如BN、L2 norm等);
(2)、更换参数初始化方法(对于CNN,一般用xavier或者msra的初始化方法);
(3)、减小学习率、减小batch size;
(4)、加入gradient clipping;