MultiLabelSoftMarginLoss
I am not sure why PyTorch chose this name: looking at the loss formula, no margin is involved at all (perhaps one will be added in a later version). As I understand it, this is simply a multi-label cross-entropy loss, and verification below shows that its output matches BCEWithLogitsLoss. The torch version used here is 1.5.0.
https://pytorch.org/docs/stable/generated/torch.nn.MultiLabelSoftMarginLoss.html#torch.nn.MultiLabelSoftMarginLoss
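For a target with C labels, the formula in the documentation linked above amounts to (writing \sigma for the sigmoid):

loss(x, y) = -\frac{1}{C} \sum_{i} \Big[ y_i \log \sigma(x_i) + (1 - y_i) \log\big(1 - \sigma(x_i)\big) \Big]

i.e. a per-label binary cross entropy averaged over the labels, with no margin term anywhere.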
BCEWithLogitsLoss additionally exposes a pos_weight argument to handle the imbalance between positive and negative samples; see https://blog.csdn.net/ltochange/article/details/117790534 for details.
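As a quick sketch of how pos_weight is used (the weight values here are made up for illustration, not taken from the linked post), it has one entry per label and up-weights the positive term of that label:

import torch
import torch.nn as nn

# one pos_weight entry per label; values > 1 up-weight positive examples of that label
# (illustrative values only)
pos_weight = torch.tensor([1.0, 2.0, 5.0])
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([[0.8, 0.9, 0.3]])
targets = torch.tensor([[1.0, 1.0, 0.0]])
print(loss_fn(logits, targets).item())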
Full example comparing the built-in losses with a manual implementation:
import torch
import torch.nn as nn
import math

# Manual re-implementation of the loss, used to verify the built-in modules.
def validate_loss(output, target, weight=None, pos_weight=None):
    output = torch.sigmoid(output)
    # pos_weight handles the imbalance between positive and negative samples
    if pos_weight is None:
        label_size = output.size(1)
        pos_weight = torch.ones(label_size)
    # weight handles the imbalance across labels
    if weight is None:
        label_size = output.size(1)
        weight = torch.ones(label_size)
    val = 0
    for li_x, li_y in zip(output, target):
        for i, xy in enumerate(zip(li_x, li_y)):
            x, y = xy
            loss_val = pos_weight[i] * y * math.log(x) + (1 - y) * math.log(1 - x)
            val += weight[i] * loss_val
    # average over all samples and all labels
    return -val / (output.size(0) * output.size(1))

weight = torch.Tensor([0.8, 1, 0.8])
loss = nn.MultiLabelSoftMarginLoss(weight=weight)
x = torch.Tensor([[0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3]])
y = torch.Tensor([[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]])
print(x.size())
print(y.size())

# built-in MultiLabelSoftMarginLoss
loss_val = loss(x, y)
print(loss_val.item())

# manual verification
manual_val = validate_loss(x, y, weight=weight)
print(manual_val.item())

# BCEWithLogitsLoss with the same per-label weight
loss = torch.nn.BCEWithLogitsLoss(weight=weight)
loss_val = loss(x, y)
print(loss_val.item())
Output:
torch.Size([4, 3])
torch.Size([4, 3])
0.4405062198638916
0.4405062198638916
0.440506249666214
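For reference, the same number can also be reproduced in a single call through the functional interface (a minimal sketch; with its default reduction='mean', binary_cross_entropy_with_logits averages over every element, which is exactly what the manual loop above does):

import torch
import torch.nn.functional as F

weight = torch.Tensor([0.8, 1, 0.8])
x = torch.Tensor([[0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3]])
y = torch.Tensor([[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]])

# should match the values printed above (~0.4405)
print(F.binary_cross_entropy_with_logits(x, y, weight=weight).item())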