pytorch中常用的损失函数

  • 1. pytorch中常用的损失函数列举:

pytorch中的nn模块提供了很多可以直接使用的loss函数, 比如MSELoss(), CrossEntropyLoss(), NLLLoss() 等

官方链接: https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html

pytorch中常用的损失函数
损失函数名称适用场景
torch.nn.MSELoss()均方误差损失回归
torch.nn.L1Loss()平均绝对值误差损失回归
torch.nn.CrossEntropyLoss()交叉熵损失多分类
torch.nn.NLLLoss()负对数似然函数损失多分类
torch.nn.NLLLoss2d()图片负对数似然函数损失图像分割
torch.nn.KLDivLoss()KL散度损失回归
torch.nn.BCELoss()二分类交叉熵损失二分类
torch.nn.MarginRankingLoss()评价相似度的损失 
torch.nn.MultiLabelMarginLoss()多标签分类的损失多标签分类
torch.nn.SmoothL1Loss()平滑的L1损失回归
torch.nn.SoftMarginLoss()多标签二分类问题的损失

多标签二分类

 

 

  • 2. 比较CrossEntropyLoss() 和NLLLoss():

(1). CrossEntropyLoss():

torch.nn.CrossEntropyLoss(weight=None,   # 1D张量,含n个元素,分别代表n类的权重,样本不均衡时常用
                          size_average=None, 
                          ignore_index=-100, 
                          reduce=None, 
                          reduction='mean' )

参数:

weight: 1D张量,含n个元素,分别代表n类的权重,样本不均衡时常用, 默认为None.

计算公式:

weight = None时:

  

weight ≠ None时:

输入:

output: 网络未加softmax的输出

target: label值(0,1,2  不是one-hot)

代码:

loss_func = CrossEntropyLoss(weight=torch.from_numpy(np.array([0.03,0.05,0.19,0.26,0.47])).float().to(device) ,size_average=True)
loss = loss_func(output, target)

(2). NLLLoss():

torch.nn.NLLLoss(weight=None, 
                size_average=None, 
                ignore_index=-100,
                reduce=None, 
                reduction='mean')

输入:

output: 网络在logsoftmax后的输出  

target: label值(0,1,2   不是one-hot)

代码:

loss_func = NLLLoss(weight=torch.from_numpy(np.array([0.03,0.05,0.19,0.26,0.47])).float().to(device) ,size_average=True)
loss = loss_func(output, target)

 

(3). 二者总结比较:

总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:

####################---CrossEntropyLoss()---#######################

loss_func = CrossEntropyLoss()
loss = loss_func(output, target)

####################---Softmax+log+NLLLoss()---####################

self.softmax = nn.Softmax(dim = -1)

x = self.softmax(x)
output = torch.log(x)

loss_func = NLLLoss()
loss = loss_func(output, target)

####################---LogSoftmax+NLLLoss()---######################

self.log_softmax = nn.LogSoftmax(dim = -1)

output = self.log_softmax(x)

loss_func = NLLLoss()
loss = loss_func(output, target)

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值