用 torch.mean 沿三维张量的每一个维度分别求均值
1. dim=0
import torch

# 3-D example tensor holding the values 0..7, shape (2, 2, 2).
a = torch.arange(8, dtype=torch.float32).view(2, 2, 2)
print(a)
# Average over dim 0: that dimension is removed, so the result has shape (2, 2).
mean = torch.mean(a, dim=0)
print(mean, mean.shape)
输出结果:
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
tensor([[2., 3.],
[4., 5.]]) torch.Size([2, 2])
2. dim=1
import torch

# 3-D example tensor holding the values 0..7, shape (2, 2, 2).
a = torch.arange(8, dtype=torch.float32).view(2, 2, 2)
print(a)
# Average over dim 1: that dimension is removed, so the result has shape (2, 2).
mean = torch.mean(a, dim=1)
print(mean, mean.shape)
输出结果:
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
tensor([[1., 2.],
[5., 6.]]) torch.Size([2, 2])
3. dim=2
import torch

# 3-D example tensor holding the values 0..7, shape (2, 2, 2).
a = torch.arange(8, dtype=torch.float32).view(2, 2, 2)
print(a)
# Average over dim 2 (the last one): it is removed, so the result has shape (2, 2).
mean = torch.mean(a, dim=2)
print(mean, mean.shape)
输出结果:
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
tensor([[0.5000, 2.5000],
[4.5000, 6.5000]]) torch.Size([2, 2])
补充:如果给 torch.mean 传入 keepdim=True(也可以直接把 True 写成第三个位置参数),被求均值的那个维度不会被删除,而是保留下来且大小变为 1,因此结果的维数与原张量相同,如下:
import torch

# 3-D example tensor holding the values 0..7, shape (2, 2, 2).
a = torch.arange(8, dtype=torch.float32).view(2, 2, 2)
print(a)
# Average over dim 2 with keepdim=True: dim 2 is kept with size 1 instead of
# being removed, so the result has shape (2, 2, 1).
mean = torch.mean(a, dim=2, keepdim=True)
print(mean, mean.shape)
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
tensor([[[0.5000],
[2.5000]],
[[4.5000],
[6.5000]]]) torch.Size([2, 2, 1])
补充:更高维(四维、五维)张量沿不同维度求均值(keepdim=True)时的形状变化
import torch

# 4-D example tensor holding the values 0..15, shape (2, 2, 2, 2).
a = torch.arange(16, dtype=torch.float32).view(2, 2, 2, 2)
print(a)
# Average over dim 0 with keepdim=True: dim 0 becomes size 1 -> shape (1, 2, 2, 2).
mean = torch.mean(a, dim=0, keepdim=True)
print(mean, mean.shape)
tensor([[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]],
[[[ 8., 9.],
[10., 11.]],
[[12., 13.],
[14., 15.]]]])
tensor([[[[ 4., 5.],
[ 6., 7.]],
[[ 8., 9.],
[10., 11.]]]]) torch.Size([1, 2, 2, 2])
import torch

# 4-D example tensor holding the values 0..15, shape (2, 2, 2, 2).
a = torch.arange(16, dtype=torch.float32).view(2, 2, 2, 2)
print(a)
# Average over dim 1 with keepdim=True: dim 1 becomes size 1 -> shape (2, 1, 2, 2).
mean = torch.mean(a, dim=1, keepdim=True)
print(mean, mean.shape)
tensor([[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]],
[[[ 8., 9.],
[10., 11.]],
[[12., 13.],
[14., 15.]]]])
tensor([[[[ 2., 3.],
[ 4., 5.]]],
[[[10., 11.],
[12., 13.]]]]) torch.Size([2, 1, 2, 2])
import torch

# 4-D example tensor holding the values 0..15, shape (2, 2, 2, 2).
a = torch.arange(16, dtype=torch.float32).view(2, 2, 2, 2)
print(a)
# Average over dim 2 with keepdim=True: dim 2 becomes size 1 -> shape (2, 2, 1, 2).
mean = torch.mean(a, dim=2, keepdim=True)
print(mean, mean.shape)
tensor([[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]],
[[[ 8., 9.],
[10., 11.]],
[[12., 13.],
[14., 15.]]]])
tensor([[[[ 1., 2.]],
[[ 5., 6.]]],
[[[ 9., 10.]],
[[13., 14.]]]]) torch.Size([2, 2, 1, 2])
import torch

# 4-D example tensor holding the values 0..15, shape (2, 2, 2, 2).
a = torch.arange(16, dtype=torch.float32).view(2, 2, 2, 2)
print(a)
# Average over dim 3 (the last one) with keepdim=True: dim 3 becomes size 1
# -> shape (2, 2, 2, 1).
mean = torch.mean(a, dim=3, keepdim=True)
print(mean, mean.shape)
tensor([[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]],
[[[ 8., 9.],
[10., 11.]],
[[12., 13.],
[14., 15.]]]])
tensor([[[[ 0.5000],
[ 2.5000]],
[[ 4.5000],
[ 6.5000]]],
[[[ 8.5000],
[10.5000]],
[[12.5000],
[14.5000]]]]) torch.Size([2, 2, 2, 1])
import torch

# 5-D example tensor: the values 0..15 repeated twice (32 elements),
# shape (2, 2, 2, 2, 2).
a = torch.arange(16, dtype=torch.float32).repeat(2).view(2, 2, 2, 2, 2)
print(a)
# Average over dim 3 with keepdim=True: dim 3 becomes size 1
# -> shape (2, 2, 2, 1, 2).
mean = torch.mean(a, dim=3, keepdim=True)
print(mean, mean.shape)
tensor([[[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]],
[[[ 8., 9.],
[10., 11.]],
[[12., 13.],
[14., 15.]]]],
[[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]],
[[[ 8., 9.],
[10., 11.]],
[[12., 13.],
[14., 15.]]]]])
tensor([[[[[ 1., 2.]],
[[ 5., 6.]]],
[[[ 9., 10.]],
[[13., 14.]]]],
[[[[ 1., 2.]],
[[ 5., 6.]]],
[[[ 9., 10.]],
[[13., 14.]]]]]) torch.Size([2, 2, 2, 1, 2])