pytorch之SmoothL1Loss原理与用法

本文详细介绍了SmoothL1Loss损失函数的数学原理,通过图像展示了其分段线性特性。当误差增大时,损失增长相对平缓,避免了平方误差的爆炸性增长。此外,提供了Python代码示例来演示如何在PyTorch中使用该损失函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

官方说明:

在这里插入图片描述

解读

我们直接看那个loss计算公式 l n l_n ln,可以发现,是一个分段函数,我们将绝对值差视为一个变量 z z z,那么这个变量是大于0的,即分段函数只在大于等于0处有定义,有图像。我们再来看看分段点,就是beta。

有意思的是,在分段函数和这个分段点有关,在第一个公式(左边分段函数)中,函数值小于等于 0.5 z 0.5z 0.5z,因为除了beta。右边分段函数中,大于等于 0.5 z 0.5z 0.5z。所以是连续的,所以叫做Smooth。

而且beta固定下来的时候,当 z z z很大时,损失是线性函数,也就是说损失不会像MSE那样平方倍的爆炸。

总结就是:前半段随着 z z z的增长,损失增长得非常缓慢,后面快了一点点,但是也仍然是线性的。

图像

plt.figure(figsize=(20,8),dpi=80)
beta=[0.5,1,2,3]
for i in range(len(beta)):
    
    x1=np.linspace(0,beta[i],21)
    y1=0.5*x1*x1/beta[i]
    x2=np.linspace(beta[i],6,21)
    y2=x2-0.5*beta[i]
    plt.plot(np.hstack([x1,x2]),np.hstack([y1,y2]),label="beta:{}".format(beta[i]))
    
plt.xlabel("the absolute element-wise error")
plt.ylabel("the real loss")
plt.legend()

在这里插入图片描述

用法

import torch
import torch.nn as nn
a=[1,2,3]
b=[3,1,9]
loss_fn=nn.SmoothL1Loss()
loss_fn(torch.tensor(a,dtype=torch.float32),torch.tensor(b,dtype=torch.float32))

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

音程

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值