pytorch中数组维度理解与numpy中类似,pytorch中维度用dim
表示,numpy中用axis
表示
这里主要想说下维度的变化。
dim = x ,表示在第x为上进行操作,那个维度会发生变化。
一、二维数组
1. 两个二维数组的拼接
维度为(2,3)与(2,4)的数组拼接后的维度是(2,7)
import torch
a = torch.Tensor(np.arange(6).reshape(2,3))
b = torch.Tensor(np.arange(8).reshape(2,4))
print(a,'\n ',a.shape)
print(b,'\n',b.shape)
c = torch.cat((a,b),dim = 1)
print('concatenate:\n',c,'\n',c.shape)
结果
tensor([[0., 1., 2.],
[3., 4., 5.]])
a: torch.Size([2, 3])
tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]])
torch.Size([2, 4])
concatenate:
tensor([[0., 1., 2., 0., 1., 2., 3.],
[3., 4., 5., 4., 5., 6., 7.]])
torch.Size([2, 7])
2. 二维数组求sum、max等
dim = 0,第一个维度划掉,得到一个一维向量。比如,a是(2,3),dim = 0,得到的结果是(3,)维的;如果dim=1,得到的结果是(2,)
print('sum dim=0',torch.sum(a,dim=0))
print('sum dim=1',torch.sum(a,dim=1))
print('******* max *****')
print('max dim=0',torch.max(a,dim=0))
print('max dim=1',torch.max(a,dim=1))
输出
tensor([[0., 1., 2.],
[3., 4., 5.]])
torch.Size([2, 3])
sum dim=0 tensor([3., 5., 7.])
sum dim=1 tensor([ 3., 12.])
******* max *****
max dim=0 torch.return_types.max(
values=tensor([3., 4., 5.]),
indices=tensor([1, 1, 1]))
max dim=1 torch.return_types.max(
values=tensor([2., 5.]),
indices=tensor([2, 2]))
二、三维数组
1. 两个三维数组的拼接
两个三位数组拼接,有个要求,除了dim维,其余维的维度要相同。
- 比如 a是(2,3,4),b是(3,2,4)那么a与b无论在哪个维上都不能拼接。因为它们没有两个相同的维度。
- 如果a与b维度相同,都是(2,3,4),那么他们无论在哪个维上都可以拼接。dim = 0,结果是
(4,3,4)
,dim = 1,结果是(2,6,4)
,dim =2,结果是(2,3,8)
- dim = x,就将两个数组dim维上的数字相加,得到最终输出维度。
a = torch.Tensor(np.arange(24).reshape(2,3,4))
b = torch.Tensor(np.arange(24,48).reshape(2,3,4))
print(a,'\n ',a.shape)
print(b,'\n',b.shape)
c = torch.cat((a,b),dim = 2)
print('concatenate:\n',c,'\n',c.shape)
输出结果
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
torch.Size([2, 3, 4])
tensor([[[24., 25., 26., 27.],
[28., 29., 30., 31.],
[32., 33., 34., 35.]],
[[36., 37., 38., 39.],
[40., 41., 42., 43.],
[44., 45., 46., 47.]]])
torch.Size([2, 3, 4])
concatenate:
tensor([[[ 0., 1., 2., 3., 24., 25., 26., 27.],
[ 4., 5., 6., 7., 28., 29., 30., 31.],
[ 8., 9., 10., 11., 32., 33., 34., 35.]],
[[12., 13., 14., 15., 36., 37., 38., 39.],
[16., 17., 18., 19., 40., 41., 42., 43.],
[20., 21., 22., 23., 44., 45., 46., 47.]]])
torch.Size([2, 3, 8])
2. 三维数组求sum、max等
- 类似于二维数组,会消去dim维度
- shape=(2,3,4)的数组,在dim=0上求和或者取最大后,结果的shape = (3,4)
- pytorch求max,同时返回两个值(max,indices)
a = torch.Tensor(np.arange(24).reshape(2,3,4))
print(a,'\n',a.shape)
print('sum dim=0',torch.sum(a,dim=0))
print('sum dim=1',torch.sum(a,dim=1))
print('sum dim=2',torch.sum(a,dim=2))
print('******* max *****')
print('max dim=0',torch.max(a,dim=0))
print('max dim=1',torch.max(a,dim=1))
print('max dim=2',torch.max(a,dim=2))
结果
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
torch.Size([2, 3, 4])
sum dim=0 tensor([[12., 14., 16., 18.],
[20., 22., 24., 26.],
[28., 30., 32., 34.]])
sum dim=1 tensor([[12., 15., 18., 21.],
[48., 51., 54., 57.]])
sum dim=2 tensor([[ 6., 22., 38.],
[54., 70., 86.]])
******* max *****
max dim=0 torch.return_types.max(
values=tensor([[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]),
indices=tensor([[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]]))
max dim=1 torch.return_types.max(
values=tensor([[ 8., 9., 10., 11.],
[20., 21., 22., 23.]]),
indices=tensor([[2, 2, 2, 2],
[2, 2, 2, 2]]))
max dim=2 torch.return_types.max(
values=tensor([[ 3., 7., 11.],
[15., 19., 23.]]),
indices=tensor([[3, 3, 3],
[3, 3, 3]]))