Pytorch学习(十一)Pytorch中.item()的用法

我第一次接触.item()是在做图像分类任务中,计算loss的时候。total_loss = total_loss + loss.item()

1. .item()的用法

.item()用于在只包含一个元素的tensor中提取值,注意是只包含一个元素,否则的话使用.tolist()

x = torch.tensor([1])
print(x.item())
y = torch.tensor([2,3,4,5])
print(y.item())
# 输出结果如下
1
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-e884e6ada778> in <module>
      2 print(x.item())
      3 y = torch.tensor([2,3,4,5])
----> 4 print(y.item())

ValueError: only one element tensors can be converted to Python scalars

由结果看出,print(y.item())报错了,报错说只有一个元素的tensor才能使用.item()转换成标量,所以我们把y.item()改成.tolist()试试。

x = torch.tensor([1])
print(x.item())
y = torch.tensor([2,3,4,5])
print(y.tolist())
# 输出结果
1
[2, 3, 4, 5]
2. 为何要使用.item()

在训练时统计loss变化时,会用到loss.item(),能够防止tensor无线叠加导致的显存爆炸

3. 为何loss还需要乘batch_size呢

比如这个语句: train_loss += loss.item() * images.size(0) ,为什么总的损失train_loss不直接累加loss.item(),还在后边乘上images.size(0)呢?
loss.item()应该是一个batch size的平均损失,×images.size(0)那就是一个batch size的总损失,所以train_loss很可能是求一个epoch的loss之和。

  • 51
    点赞
  • 121
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值