交叉熵损失函数代码中的那些坑

最近在用交叉熵损失函数,但是却频频出现bug,这里把坑都记录一下,避免以后又再一次掉进去,也希望能帮助到掉进去的人出来。

  1. torch.nn.CrossEntropyLoss()
    首先,这是一个,在用的时候需要先创建对象,然后把参数传给对象。例如
# 正确示例
loss = torch.nn.CrossEntropyLoss()
loss = loss(predict, target.long())
# 或者
loss = torch.nn.CrossEntropyLoss()(predict, target.long())
# 错误示例
loss = torch.nn.CrossEntropyLoss(predict, target.long())

如果错用了上述“错误示例”,就会报错"RuntimeError: Boolean value of Tensor with more than one value is ambiguous"

其次,要说的是传给对象的参数,第一个predict是网络的直接输出,是含有正数、负数的一些乱七八糟的数,如果是3D的,predict.shape应该是[B, Classes, D, H, W],如果是2D的,predict.shape应该是[B, Classes, H, W],其中B是batch_size,Classes是类别数,是几分类就是几(考虑背景)。第二个参数target是标签,是索引标签,3D中的大小是[B, D, H, W],2D中的大小是[B, H, W],可能你会问了,predict和target的大小不一样,怎么计算loss呢?因为torch.nn.CrossEntropyLoss()内部封装了将target转换成和predict大小一样的onehot编码的代码,所以只需要传进去索引编码就可以,不用自己转换成onehot编码

代码示例
predict = torch.randn(3, 2)
labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0])  # onehot编码格式
label2 = torch.LongTensor([1, 0, 0])  # 索引标签
loss = torch.nn.CrossEntropyLoss()(predict, label2)

同时需要注意,传进去的target应该是LongTensor类型的,如果不是,需要强制转换一下,否则就会报错,应该是这个错误“”RuntimeError: Expected object of type torch.cuda.LongTensor but found type torch.cuda.FloatTensor for argument #2 ‘target’“”
2. torch.nn.BCELoss()

代码示例
predict = torch.randn(3, 2)
labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0])
loss = torch.nn.BCELoss()(torch.nn.Sigmoid(predict), labels)

由上述代码可以看出,torch.nn.BCELoss()没有封装将索引编码转换成onehot编码的代码,需要自己现将索引编码转换成onehot编码,然后在上传参数。并且,predict参数也不再是网络的直接输出,而是经过sigmoid之后的,某个维度上所有值之和为1。
此外,将索引编码转换成onehot编码,可以使用**torch.nn.functional.onehot()**函数。
3. torch.nn.BCEWithLogitsLoss()

代码示例
predict = torch.randn(3, 2)
labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0])
loss = torch.nn.BCEWithLogitsLoss()(predict, labels)

再看torch.nn.BCEWithLogitsLoss()这个类,它需要上传的参数是网络的直接输出和onehot编码格式的target。
4. torch.nn.functional.binary_cross_entropy()
需要注意的是,上述前三个都是类,传参数时需要先创建对象,然后将参数传给对象,但是torch.nn.functional里面的是函数,直接传参数即可。要注意这一点,避免不必要的bug。

代码示例
predict = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = F.binary_cross_entropy(F.sigmoid(predict), target)

这是官方文档的一个例子,首先,该函数适用于二分类的交叉熵损失函数计算,其次对于predict参数,是需要经过sigmoid之后的,target是onehot编码格式的,因为target和predict的大小是一样的。如果是3D,那么predict和target的shape都应该是[B, Classes, D, H, W],如果是2D,那么它们的大小都应该是[B, Classes, H, W],其中各个字母的含义和上述相同。

注:不知道该怎么用时,查官方文档真的很有用,很有用,很有用!

  • 12
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值