import torch
x = torch.arange(24).view(2,3,4).float()
print(x)
"""
输出是这样的
[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]]
"""
如上,创建一个shape为(2,3,4)的张量。我们可以把它看成一个图片,图片的shape为
(channels=2, width=3, height=4)
1.所谓的mean就是取平均值,无非就是看取哪一个维度的平均值
例如,在第0个维度channels上取平均值:
import torch
x = torch.arange(24).view(2,3,4).float()
y1 = x.mean([0])
print(y1)
"""
x输出是这样的
[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]]
"""
"""
y1输出结果为:
[[ 6., 7., 8., 9.],
[10., 11., 12., 13.],
[14., 15., 16., 17.]]
"""
通过对比x和y1的输出结果可以看出:对于一个(channels,width, height)的张量,如果在channels上取均值的话,就是对每一个channel相加然后取平均值,它输出的shape为(width, height)。
以RGB图像为例,在channels上取均值,就是对RGB三个通道计算平均值,最后变成一个通道。