Usage Summary of torch.mean(), torch.max(), nn.AdaptiveAvgPool2d(), and nn.AdaptiveMaxPool2d()


🤵 Author: Horizon John



torch.randn()

First, create a tensor with 2 channels, each of size 3 x 5 (overall shape 1 x 2 x 3 x 5):

import torch
import torch.nn as nn

x = torch.randn(1, 2, 3, 5)
print(x)

Output:

tensor([[[[ 0.9529, -0.3196,  0.3037,  0.0507,  0.3179],
          [-2.0785, -0.8697,  0.2880,  0.4416,  0.7148],
          [-1.7157,  0.6767, -0.5997,  1.1221,  0.2088]],

         [[-1.1475,  0.2046, -0.7609, -1.2099, -0.5255],
          [-0.1989, -0.6862, -0.6264,  0.1123,  0.3257],
          [ 1.5529, -0.6140, -0.0471, -0.6173, -0.9239]]]])
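
For reference, the shape confirms the (batch, channels, height, width) layout assumed throughout this post. A minimal check:

print(x.shape)   # torch.Size([1, 2, 3, 5]) -> batch of 1, 2 channels, 3 x 5 feature map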

torch.mean()

dim=0: takes the mean over the batch (first) dimension; the channel count is unchanged (here the batch size is 1, so the output equals the input)
dim=1: takes the mean over the channel dimension; the channel count is reduced to 1
No dim argument: returns the mean of all elements as a single scalar

avg_out = torch.mean(x)
print(avg_out)

avg_out0 = torch.mean(x, dim=0, keepdim=True)
print(avg_out0)

avg_out1 = torch.mean(x, dim=1, keepdim=True)
print(avg_out1)

Output:

tensor(-0.1889)

tensor([[[[ 0.9529, -0.3196,  0.3037,  0.0507,  0.3179],
          [-2.0785, -0.8697,  0.2880,  0.4416,  0.7148],
          [-1.7157,  0.6767, -0.5997,  1.1221,  0.2088]],

         [[-1.1475,  0.2046, -0.7609, -1.2099, -0.5255],
          [-0.1989, -0.6862, -0.6264,  0.1123,  0.3257],
          [ 1.5529, -0.6140, -0.0471, -0.6173, -0.9239]]]])
          
tensor([[[[-0.0973, -0.0575, -0.2286, -0.5796, -0.1038],
          [-1.1387, -0.7779, -0.1692,  0.2770,  0.5203],
          [-0.0814,  0.0313, -0.3234,  0.2524, -0.3575]]]])
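
A quick way to see what each call reduced is to print the shapes (a small sketch, using the same x and results as above):

print(avg_out.shape)    # torch.Size([])           -> scalar mean of all elements
print(avg_out0.shape)   # torch.Size([1, 2, 3, 5]) -> dim=0 (batch) reduced, keepdim=True
print(avg_out1.shape)   # torch.Size([1, 1, 3, 5]) -> dim=1 (channel) reduced to 1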

torch.max()

dim=0: takes the maximum over the batch (first) dimension; the channel count is unchanged (here the batch size is 1, so the output equals the input)
dim=1: takes the maximum over the channel dimension; the channel count is reduced to 1
No dim argument: returns the maximum of all elements as a single scalar

max_out = torch.max(x)
print(max_out)

max_out0, _ = torch.max(x, dim=0, keepdim=True)
print(max_out0)

max_out1, _ = torch.max(x, dim=1, keepdim=True)
print(max_out1)

Output:

tensor(1.5529)

tensor([[[[ 0.9529, -0.3196,  0.3037,  0.0507,  0.3179],
          [-2.0785, -0.8697,  0.2880,  0.4416,  0.7148],
          [-1.7157,  0.6767, -0.5997,  1.1221,  0.2088]],

         [[-1.1475,  0.2046, -0.7609, -1.2099, -0.5255],
          [-0.1989, -0.6862, -0.6264,  0.1123,  0.3257],
          [ 1.5529, -0.6140, -0.0471, -0.6173, -0.9239]]]])
          
tensor([[[[ 0.9529,  0.2046,  0.3037,  0.0507,  0.3179],
          [-0.1989, -0.6862,  0.2880,  0.4416,  0.7148],
          [ 1.5529,  0.6767, -0.0471,  1.1221,  0.2088]]]])

Additionally, when dim is specified and the result is not unpacked, torch.max returns a named tuple containing both the values and the indices:

out = torch.max(x, dim=0, keepdim=True)
print(out)

Output:

torch.return_types.max(
values=tensor([[[[ 0.9529, -0.3196,  0.3037,  0.0507,  0.3179],
          [-2.0785, -0.8697,  0.2880,  0.4416,  0.7148],
          [-1.7157,  0.6767, -0.5997,  1.1221,  0.2088]],

         [[-1.1475,  0.2046, -0.7609, -1.2099, -0.5255],
          [-0.1989, -0.6862, -0.6264,  0.1123,  0.3257],
          [ 1.5529, -0.6140, -0.0471, -0.6173, -0.9239]]]]),
indices=tensor([[[[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]]]]))
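
The returned torch.return_types.max can also be accessed by field name instead of unpacking, and torch.amax returns only the values when the indices are not needed (a sketch, assuming a PyTorch version recent enough to provide torch.amax):

out = torch.max(x, dim=1, keepdim=True)
print(out.values.shape)    # torch.Size([1, 1, 3, 5])
print(out.indices.shape)   # torch.Size([1, 1, 3, 5])

print(torch.amax(x, dim=1, keepdim=True).shape)  # values only, torch.Size([1, 1, 3, 5])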

nn.AdaptiveAvgPool2d()

Given the target output feature-map size, adaptive average pooling resizes each channel's spatial dimensions to that size; the number of channels is unchanged.

avg_pool0 = nn.AdaptiveAvgPool2d(1)
avg_pooled0 = avg_pool0(x)
print(avg_pooled0)

avg_pool1 = nn.AdaptiveAvgPool2d((2, 2))
avg_pooled1 = avg_pool1(x)
print(avg_pooled1)

Output:

tensor([[[[-0.0337]],

         [[-0.3442]]]])
         
tensor([[[[-0.2872,  0.3528],
          [-0.7165,  0.3626]],

         [[-0.5359, -0.4475],
          [-0.1033, -0.2961]]]])
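
nn.AdaptiveAvgPool2d(1) is equivalent to averaging over the spatial dimensions (H, W) with keepdim=True. A small sketch to verify, using the same x:

pooled = nn.AdaptiveAvgPool2d(1)(x)         # adaptive average pooling to a 1 x 1 map
manual = x.mean(dim=(2, 3), keepdim=True)   # mean over the spatial dims
print(torch.allclose(pooled, manual))       # True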

nn.AdaptiveMaxPool2d()

Given the target output feature-map size, adaptive max pooling resizes each channel's spatial dimensions to that size; the number of channels is unchanged.

max_pool0 = nn.AdaptiveMaxPool2d(1)
max_pooled0 = max_pool0(x)
print(max_pooled0)

max_pool1 = nn.AdaptiveMaxPool2d((2, 2))
max_pooled1 = max_pool1(x)
print(max_pooled1)

Output:

tensor([[[[1.1221]],

         [[1.5529]]]])

tensor([[[[0.9529, 0.7148],
          [0.6767, 1.1221]],

         [[0.2046, 0.3257],
          [1.5529, 0.3257]]]])
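
Similarly, nn.AdaptiveMaxPool2d(1) gives the per-channel spatial maximum; the sketch below checks it against torch.amax over (H, W), again assuming a PyTorch version where torch.amax accepts multiple dims:

pooled = nn.AdaptiveMaxPool2d(1)(x)               # adaptive max pooling to a 1 x 1 map
manual = torch.amax(x, dim=(2, 3), keepdim=True)  # maximum over the spatial dims
print(torch.equal(pooled, manual))                # True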

