torch.argmax( ,dim=1)

**

torch.argmax()

**
维度本身就很难理解,作为一名小白,这篇用来给同样有维度困扰的你们,里面的例子引用了:https://blog.csdn.net/weixin_43869268/article/details/107624108

x = torch.randn(2,4,5)

print(x)
> tensor([[[-0.2377, -1.0894,  1.7376, -1.7329,  0.4314],
         [-1.4899, -0.7765, -0.8258, -0.5007,  0.3901],
         [ 0.5550,  1.1423, -1.3531, -0.7470,  0.5840],
         [-0.7997,  0.0681,  0.7373, -0.0090, -0.9902]],

        [[-0.6789, -1.6443, -0.3977, -0.0783,  1.6161],
         [-1.5234, -1.1999, -0.9181,  0.8778,  0.5733],
         [ 1.1922, -0.4699, -0.2262,  0.9180,  0.0642],
         [-0.6872,  0.3727,  1.3596, -0.8487, -0.8330]]])

print(x.argmax(dim=0))
> tensor([[0, 0, 0, 1, 1],
        [0, 0, 0, 1, 1],
        [1, 0, 1, 1, 0],
        [1, 1, 1, 0, 1]])


print(x.argmax(dim=2))
> tensor([[2, 4, 1, 2],
        [4, 3, 0, 2]])

看到这个例子困扰了我好久,到底怎么算出来的,下面用我的语言来解释一下,如果不对,欢迎指正~

首先例子中的X的维度是2 x 4 x 5(通过查中括号内有几堆数据来判定,比如最里层有5堆数据,所以是5,这都不是这篇文章的重点啦~如果不明白请评论留言我看到就会解答)

(1)那么当dim=0的时候(最外层的两堆进行比较)也就是:

        [[-0.2377, -1.0894,  1.7376, -1.7329,  0.4314],
         [-1.4899, -0.7765, -0.8258, -0.5007,  0.3901],
         [ 0.5550,  1.1423, -1.3531, -0.7470,  0.5840],
         [-0.7997,  0.0681,  0.7373, -0.0090,-0.9902]],

和

        [[-0.6789, -1.6443, -0.3977, -0.0783,  1.6161],
         [-1.5234, -1.1999, -0.9181,  0.8778,  0.5733],
         [ 1.1922, -0.4699, -0.2262,  0.9180,  0.0642],
         [-0.6872,  0.3727,  1.3596, -0.8487, -0.8330]]

第一堆中的每个元素与第二堆中对应位置的元素进行比较,若第一堆的>第二堆的,则输出0,否则输出1。
比如:-0.2377>-0.6789,所以第一个位置输出0,以此类推,得到dim=0时的tensor。

print(x.argmax(dim=0))
> tensor([[0, 0, 0, 1, 1],
        [0, 0, 0, 1, 1],
        [1, 0, 1, 1, 0],
        [1, 1, 1, 0, 1]])

(2)当dim=2时(最内层进行比较),也就是:

         [-0.2377, -1.0894,  1.7376, -1.7329,  0.4314],
和
         [-1.4899, -0.7765, -0.8258, -0.5007,  0.3901],
和
         [ 0.5550,  1.1423, -1.3531, -0.7470,  0.5840],
和
         [-0.7997,  0.0681,  0.7373, -0.0090, -0.9902]
进行比较,得出输出的第一行
         

         
         [-0.6789, -1.6443, -0.3977, -0.0783,  1.6161],
和
         [-1.5234, -1.1999, -0.9181,  0.8778,  0.5733],
和
         [ 1.1922, -0.4699, -0.2262,  0.9180,  0.0642],
和
         [-0.6872,  0.3727,  1.3596, -0.8487, -0.8330]
进行比较,得到输出的第二行
         

比如第一行中,第三个位置的元素最大,索引是2(索引从0开始计:第一行0,1,2,3,4)所以第一行第一个元素是2,以此类推…
[-0.2377, -1.0894, 1.7376, -1.7329, 0.4314]

最终得到

print(x.argmax(dim=2))
> tensor([[2, 4, 1, 2],
        [4, 3, 0, 2]])

好啦!大功告成!希望对你们也能有所帮助

  • 45
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值