import torch
import numpy as np
# Toy example: 3 samples x 3 classes of raw scores (logits) and their
# multi-hot targets. Both are converted from numpy to float32 tensors.
pred = torch.from_numpy(
    np.array([
        [-0.4089, -1.2471, 0.5907],
        [-0.4897, -0.8267, -0.7349],
        [0.5241, -0.1246, -0.4751],
    ])
).float()
label = torch.from_numpy(
    np.array([
        [0, 1, 1],
        [0, 0, 1],
        [1, 0, 1],
    ])
).float()

# These two criteria take the raw logits directly (the preferred route).
crition1 = torch.nn.BCEWithLogitsLoss()
crition2 = torch.nn.MultiLabelSoftMarginLoss()
# BCELoss expects probabilities, so the logits go through sigmoid first.
crition3 = torch.nn.BCELoss()

loss1 = crition1(pred, label)
loss2 = crition2(pred, label)
loss3 = crition3(torch.sigmoid(pred), label)
print(loss1, loss2, loss3, sep="\n")
# All three agree:
#tensor(0.7193)
#tensor(0.7193)
#tensor(0.7193)
# Underlying implementation of torch's multilabel_soft_margin_loss, reproduced for reference
def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
                                reduce=None, reduction='mean'):
    # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
    r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None) -> Tensor

    Multi-label one-versus-all loss computed from raw scores (logits).

    For every element this computes the binary cross-entropy between
    ``target`` and ``sigmoid(input)`` in a numerically stable form (via
    ``logsigmoid``). It then averages over the class dimension and finally
    applies ``reduction`` across the remaining per-sample values.

    Args:
        input: raw scores (logits). For batched input the class dimension
            is dim 1 (presumably shaped (N, C)). A 1-D input is treated as
            a single unbatched sample whose only dimension is the class dim.
        target: targets with the same shape as ``input`` (0/1 labels).
        weight: optional tensor multiplied element-wise into the
            per-element loss before the class average.
        size_average, reduce: deprecated. If either is given, it overrides
            ``reduction`` via ``_Reduction.legacy_get_string``.
        reduction: ``'none'`` (per-sample losses), ``'mean'`` or ``'sum'``.

    Returns:
        Tensor: per-sample losses for ``'none'``, otherwise a scalar.

    Raises:
        ValueError: if ``reduction`` is not one of the accepted strings.

    See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
    """
    # These helpers live inside torch and were never imported in this file,
    # so the copied function raised NameError when called. Import them locally.
    from torch import Tensor
    from torch.nn import _reduction as _Reduction
    from torch.nn.functional import logsigmoid
    from torch.overrides import handle_torch_function, has_torch_function

    if not torch.jit.is_scripting():
        tens_ops = (input, target)
        # Defer to __torch_function__ overrides on Tensor-like objects.
        if any(type(t) is not Tensor for t in tens_ops) and has_torch_function(tens_ops):
            return handle_torch_function(
                multilabel_soft_margin_loss, tens_ops, input, target, weight=weight,
                size_average=size_average, reduce=reduce, reduction=reduction)
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    # -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], using the identity
    # log(1 - sigmoid(x)) == logsigmoid(-x) to stay numerically stable.
    loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))
    if weight is not None:
        loss = loss * weight

    # Average over the class dimension, leaving one loss value per sample.
    # A 1-D (unbatched) input has its classes on dim 0.
    class_dim = 0 if input.dim() == 1 else 1
    loss = loss.sum(dim=class_dim) / input.size(class_dim)

    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(reduction + " is not valid")