之前的:先对logit进行sigmoid,再使用pytorch的BCE。
修改:pytorch 和 tensorflow均有 内置的sigmoidBCE的函数。
torch.nn.BCEWithLogitsLoss
import torch
target = torch.ones([5, 4], dtype=torch.float32) # 4 classes, batch size = 10
output = torch.full([5, 4], 1.5) # A prediction (logit)
pos_weight = torch.ones([4]) # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(1.5)) =tensor(0.2014)
此版本比使用普通 Sigmoid 后跟 BCELoss 在数值上更稳定,因为通过将操作组合到一个层中, 我们利用对数总和 exp 技巧来实现数值稳定性。
注:target和output都要是float型,在onehot后变成long型要转换。