关于torch.max()、torch.mean()、torch.cat()的理解

本文介绍了PyTorch中用于处理张量的三个关键函数:torch.max()用于找到张量的最大值,可指定求解的维度;torch.mean()计算平均值,同样支持按维度计算;torch.cat()则用于在特定维度上连接多个张量。每个函数的参数和使用示例均被详细阐述。
摘要由CSDN通过智能技术生成

torch.max()的理解

torch.max()一共有两种形式,如下:

第一种:torch.max(input)

这种形式直接求出input中所有数的最大值,输出是一个数,且output.dim()=0,即无论input是几维,输出都为一个0维的数,注意input需为张量。

示例代码如下:

>>>x1 = torch.tensor([1, 2, 3])
>>>torch.max(x1)
tensor(3)

>>>x2 = torch.tensor([[[1, 2],[3, 4]],
                  [[5, 6],[7, 8]]])
>>>torch.max(x2)
tensor(8)
第二种:torch.max(input, dim, keepdim=False, *, out=None)

这种形式可以根据需要求第几维度上的最大值,且可以选择输出的维度是否改变,返回最大值和第几维度上最大值索引。
dim,求第dim维度的最大值,例如dim=0,求第0维上的最大值,dim=1,求第1维上的最大值;
keepdim,当keepdim=False时,输出维度input.dim()改变,否则不变。
示例代码如下:

>>>a = torch.tensor([[[[ 1.,  2.]],
                   [[ 3.,  7.]],
                   [[ 5.,  6.]]],

                  [[[ 7.,  8.]],
                   [[ 13., 10.]],
                   [[11., 12.]]]])
>>>a.shape
torch.Size([2, 3, 1, 2])

>>>max1_a = torch.max(a, 1)
>>>max_a
torch.return_types.max(
values=tensor([[[ 5.,  7.]],

        [[13., 12.]]]),
indices=tensor([[[2, 1]],

        [[1, 2]]]))
        
>>>max_a[0]
tensor([[[ 5.,  7.]],

        [[13., 12.]]])
        
>>>max_a[0].shape
torch.Size([2, 1, 2])

>>>max2_a = torch.max(a, 1, keepdim=True)
>>>max2_a[0].shape
torch.Size([2, 1, 1, 2])

torch.mean()的理解

torch.mean()同 torch.max()一样也是有两种用法,一个是求输入种所有数的平均值,一个是求输入种第几维度的平均值。
使用时同 torch.max()一样,见上torch.max()的理解,只需将max换成mean就行,这里就不举例说明了。

torch.cat()的理解

torch.cat(tensors, dim=0, *, out=None)

torch.cat(),就是在第几维度上连接张量,输入为张量元组,例如dim=0,在第0维连接,dim=1,在第1维上连接。
代码示例如下:

>>>m = torch.tensor([[[[ 1.,  2.]],
                   [[ 3.,  4.]],
                   [[ 5.,  6.]]],

                  [[[ 7.,  8.]],
                   [[ 9., 10.]],
                   [[11., 12.]]]])      
>>>n = torch.tensor([[[[ 1.,  2.]],
                   [[ 1.,  2.]],
                   [[ 2.,  1.]]],

                  [[[ 4.,  8.]],
                   [[ 5.,  8.]],
                   [[ 6.,  7.]]]])
>>>n.shape
torch.Size([2, 3, 1, 2])     
                
>>>torch.cat((m, n), dim=0)
tensor([[[[ 1.,  2.]],

         [[ 3.,  4.]],

         [[ 5.,  6.]]],


        [[[ 7.,  8.]],

         [[ 9., 10.]],

         [[11., 12.]]],


        [[[ 1.,  2.]],

         [[ 1.,  2.]],

         [[ 2.,  1.]]],


        [[[ 4.,  8.]],

         [[ 5.,  8.]],

         [[ 6.,  7.]]]])

>>>torch.cat((m, n), dim=0).shape
torch.Size([4, 3, 1, 2])

>>>torch.cat((m, n), dim=1)
tensor([[[[ 1.,  2.]],

         [[ 3.,  4.]],

         [[ 5.,  6.]],

         [[ 1.,  2.]],

         [[ 1.,  2.]],

         [[ 2.,  1.]]],


        [[[ 7.,  8.]],

         [[ 9., 10.]],

         [[11., 12.]],

         [[ 4.,  8.]],

         [[ 5.,  8.]],

         [[ 6.,  7.]]]])
         
>>>torch.cat((m, n), dim=1).shape
torch.Size([2, 6, 1, 2])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值