pytorch两种返回分类最大值的方法

前言

本文是为了回忆两种输出分类结果的方法,记录两种的用法。

正文

src=torch.rand(3,10,512)
print(src.shape)
lin=nn.Linear(512,521)
out=lin(src)
print(out.shape)
pred=out.argmax(2)
print(pred)
_,pres=torch.max(out,dim=2)
print(pres)

第一种方法是利用argmax来进行返回,获取的是最大值对应的下标。输出如下所示:

tensor([[298, 298,  22,  22, 298, 312, 472, 298,  22, 491],
        [491, 298, 491, 298, 196,  22, 298, 156, 472, 491],
        [ 22,  88,  96, 156, 298, 491,  22,  88, 110, 298]])

第二种方法是利用的max函数,第一个返回的参数为最大值的具体张量数值,第二个参数是对应的下标,具体的用法与上方的代码块一致,输出如下:

tensor([[298, 298,  22,  22, 298, 312, 472, 298,  22, 491],
        [491, 298, 491, 298, 196,  22, 298, 156, 472, 491],
        [ 22,  88,  96, 156, 298, 491,  22,  88, 110, 298]])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值