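Cause 1: data underflow
Cause: data underflow in the log function. When its argument is (or underflows to) 0, for example a matched box with zero width or height, torch.log returns -inf, and the loss becomes inf or NaN.
Fix: add a small bias (eps) inside the log.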
# original
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
# fixed: eps keeps the log argument strictly positive
eps = 1e-5
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh+eps) / variances[1]
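Here is a minimal sketch of the failure and the fix. It assumes the SSD convention of matched boxes in (x_min, y_min, x_max, y_max) form and priors in (cx, cy, w, h) form. The helper name `encode_wh` and the sample boxes are hypothetical. A zero-width matched box makes the width ratio exactly 0, so the log term is -inf without eps and finite with it.

```python
import torch

def encode_wh(matched, priors, variances, eps=1e-5):
    # Hypothetical helper mirroring the two lines above:
    # ratio of matched-box width/height to prior width/height, then log.
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    return torch.log(g_wh + eps) / variances[1]

# Degenerate ground-truth box: x_min == x_max, so its width is 0.
matched = torch.tensor([[0.2, 0.2, 0.2, 0.5]])
priors = torch.tensor([[0.3, 0.3, 0.4, 0.4]])
variances = [0.1, 0.2]

no_eps = encode_wh(matched, priors, variances, eps=0.0)
with_eps = encode_wh(matched, priors, variances)
print(no_eps)    # width term is -inf, height term is finite
print(with_eps)  # both terms finite
```

Note that eps only removes the -inf. A degenerate box still yields a large regression target (about log(1e-5) / 0.2), so filtering such boxes out of the ground truth may also be worth considering.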
https://github.com/amdegroot/ssd.pytorch/issues/162
https://www.zhihu.com/question/49346370
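Cause 2: the CTC length requirement is not met
Cause: input length < target CTC length
CTC loss requires input length >= target CTC length. Note that the target CTC length is not the same as the target length. CTC must place a blank between every pair of identical adjacent labels, so each such pair costs one extra frame. The two lengths are computed as follows: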
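Here the target CTC length means the minimum number of input frames the label needs once blanks between repeated labels are counted. The following sketch uses PyTorch's `torch.nn.CTCLoss` to show what happens when that minimum is not met. The sizes and labels are made up for illustration. The target `[1, 1]` has a plain length of 2 but needs 3 frames (1, blank, 1). With only 2 input frames, no valid alignment exists and the loss is inf. `zero_infinity=True` replaces such losses (and their gradients) with zero instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C = 2, 1, 3  # 2 input frames, batch size 1, 3 classes (class 0 is the blank)
log_probs = torch.randn(T, N, C).log_softmax(dim=2)
targets = torch.tensor([[1, 1]])      # repeated label: needs a blank in between
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([2])    # plain target length equals input length

ctc = nn.CTCLoss(blank=0, reduction="none")
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(loss)       # infinite: 2 frames cannot hold "1, blank, 1"

ctc_safe = nn.CTCLoss(blank=0, reduction="none", zero_infinity=True)
loss_safe = ctc_safe(log_probs, targets, input_lengths, target_lengths)
print(loss_safe)  # the infinite loss is replaced by 0
```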
def ctc_len(label):
    add_len = 0
    label_len = len(label)
    for i in range(label_len - 1):
        if label[i] == label[i + 1]:
            add_len += 1  # +1 because CTC must insert a blank between repeated characters
    return label_len + add_len
target_length = len(label)
target_ctc_len = ctc_len(label)
https://github.com/xingchensong/ASR-Wavnet/blob/master/datafeeder.py#L114
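Checking `input_length >= len(label)` is not enough. A data pipeline can compare against the CTC length instead and drop (or pad) samples that cannot be aligned. This is a minimal sketch that assumes each sample is an `(input_length, label)` pair. `filter_ctc_feasible` is a hypothetical helper. `ctc_len` is rewritten compactly here with the same logic as the function above.

```python
from typing import List, Sequence, Tuple

def ctc_len(label: Sequence) -> int:
    # One frame per label, plus one blank between each pair of identical neighbours.
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats

def filter_ctc_feasible(samples: List[Tuple[int, Sequence]]) -> List[Tuple[int, Sequence]]:
    # Hypothetical helper: keep only samples with enough frames for CTC to align the label.
    return [(n, label) for n, label in samples if n >= ctc_len(label)]

samples = [
    (5, "hello"),  # len 5, CTC length 6 ("ll" needs a blank) -> dropped
    (6, "hello"),  # exactly enough frames -> kept
    (4, "abc"),    # no repeats, CTC length 3 -> kept
]
kept = filter_ctc_feasible(samples)
print(kept)  # [(6, 'hello'), (4, 'abc')]
```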