PYTORCH 损失函数 (EXPECTED TO BE IN RANGE OF [-1, 0], BUT GOT 1)-检查batchsize

一般情况下的分类问题计算交叉熵损失

import torch.nn as nn
import torch
func=nn.CrossEntropyLoss()
a=torch.Tensor([[ 0.0606,0.1610,0.2990,0.2101, 0.5104],
                [0.6388,0.4053, 0.4196, 0.7060, 0.2793],
                [ 0.3973,0.6114, 0.1127, 0.7732, 0.0592]])
b=[3,1,0]
b=torch.Tensor(b)
loss=func(a,b.long())
loss=func(a,b.long())
print("总loss:",loss)

以上是batchsize为3的情况,但是batchsize为1(如下)会报错:

a1=torch.Tensor([ 0.0606,0.1610,0.2990,0.2101, 0.5104])
a2=torch.Tensor([0.6388,0.4053, 0.4196, 0.7060, 0.2793])
a3=torch.Tensor([ 0.3973,0.6114, 0.1127, 0.7732, 0.0592])

b1=torch.Tensor([3])
b2=torch.Tensor([1])
b3=torch.Tensor([0])

loss_1=func(a1,b1.long())
loss_2=func(a2,b2.long())
loss_3=func(a3,b3.long())

原因是a.type: torch.Size([3, 5])的大小分出来是a1.type: torch.Size([5]),需要维度为[1,5],解决方法:

a1=torch.unsqueeze(a1,0)
a2=torch.unsqueeze(a2,0)
a3=torch.unsqueeze(a3,0)

print("loss1:",loss_1)
print("loss2:",loss_2)
print("loss3",loss_3)
loss_sum=loss_1+loss_2+loss_3
print("loss_sum",loss_sum/3)

文章主要参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值