pytorch复现loss遇到的问题

本文详细介绍了如何在PyTorch框架下复现交叉熵损失函数,并解决在实现过程中遇到的detach()问题及大量nan值的产生。通过将numpy函数替换为torch函数,解决了数据类型不匹配的问题,同时使用clamp()函数限制loss计算值的范围,避免了数值不稳定导致的nan值。
摘要由CSDN通过智能技术生成

复现交叉熵

首先定义了函数对照公式实现了交叉熵的功能

def CrossEntropy(inputs, targets):
        return np.sum(np.nan_to_num(-targets*(np.log(inputs)))

运用到项目代码中出现了detach()问题,且因为类型非variable 无法更新梯度,无法backward
由于我的项目代码使用的torch框架,内部数据类型全是tensor,而用了numpy之后的数据类型全部变成了array,遂将numpy的函数全替换为torch的函数,即可针对tensor运算
且新的变量自动全是variable类型,可顺利反向传播
实现好后运行结果出现大量的nan,无法正常运算,使用clamp限制loss计算值的范围
class CrossEntropy(nn.Module):
    def __init__(self):
        super(CrossEntropy, self).__init__()
        
    def forward(self, inputs, targets):
        ## torch中要想实现backward就不能使用np,不能用array,只能使用tensor,只有tensor才有requires_grad参数
        loss1=-targets*(torch.log(inputs)).cuda()
        loss=torch.sum(loss1.clamp(min=0.0001,max=1.0))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值