torch.nn.BCEWithLogitsLoss 相当于 torch.sigmoid + torch.nn.BCELoss，但它把两步合并为一个算子并利用 log-sum-exp 技巧，因此数值上更稳定。代码示例如下：
import torch
import torch.nn as nn
# Show that nn.BCEWithLogitsLoss(x, y) equals nn.BCELoss(sigmoid(x), y).
# Both criteria use the default reduction="mean", so each returns the average
# loss over all elements as a 0-dim tensor.
bce_with_logits_loss = nn.BCEWithLogitsLoss()
bce_loss = nn.BCELoss()

# Fixed (rather than unseeded random) logits so the recorded output below is
# reproducible. One sample has target 1 and the other target 0, so both terms
# of the binary cross-entropy formula are exercised.
x = torch.tensor([1.0, -2.0])  # raw logits (pre-sigmoid scores)
y = torch.tensor([1.0, 0.0])   # float targets, same shape as x

# BCEWithLogitsLoss consumes the raw logits directly...
loss_bce_with_logits = bce_with_logits_loss(x, y)
# ...whereas BCELoss expects probabilities, so the sigmoid is applied first.
loss_bce = bce_loss(torch.sigmoid(x), y)

print("BCEWithLogitsLoss:", loss_bce_with_logits)
print("BCELoss:", loss_bce)

# Verify the equivalence programmatically instead of by eye. For moderate
# logits the two agree up to float rounding; for large-magnitude logits they
# can diverge, because sigmoid() saturates in float32 and BCELoss clamps its
# log outputs to >= -100, while BCEWithLogitsLoss stays numerically stable.
if not torch.allclose(loss_bce_with_logits, loss_bce):
    raise AssertionError("BCEWithLogitsLoss and sigmoid + BCELoss disagree")
"""
BCEWithLogitsLoss: tensor(0.2201)
BCELoss: tensor(0.2201)
"""