pytorch学习笔记(五)——pytorch基础之维度变换
目录
维度重塑----view/reshape
view与reshape用法基本一致,view是pytorch0.3版本的api,在pytorch0.4版本为保持与numpy中重塑维度函数名一致,又引入了reshape
In [1]: import torch
In [2]: a = torch.rand(4,1,28,28)
#随机生成一个四维度的tensor 用来表示灰度图
#分别对应 batchsize channel height width
In [3]: a.shape
Out[3]: torch.Size([4, 1, 28, 28])
In [4]: a.view(4,28*28)#进行维度重塑 只要 乘积一样就不会报错
Out[4]:
tensor([[0.3060, 0.2680, 0.3763, ..., 0.6596, 0.5645, 0.8772],
[0.7751, 0.9969, 0.0864, ..., 0.7230, 0.4374, 0.1659],
[0.5146, 0.5350, 0.1214, ..., 0.2056, 0.2646, 0.8539],
[0.5737, 0.5637, 0.2420, ..., 0.7731, 0.6198, 0.6113]])
#相当于把channel height width合并
In [5]: a.view(4,28*28).shape
Out[5]: torch.Size([4, 784])
In [6]: a.view(4*28,28).shape
Out[6]: torch.Size([112, 28])
#相当于把 batchsize channel height合并
In [7]: a.view(4*1,28,28).shape
Out[7]: torch.Size([4, 28, 28])
#相当于把 batchsize channel 合并
In [8]: b = a.view(4,784)
In [9]: b.view(4,28,28,1)
弊端:采用view之后 如果要还原数据 必须要知道数据存储的顺序 不然顺序搞错会导致数据的污染 如最后两行代码
增删维度----squeeze/unsqueeze
使用unsqueeze增加维度
In [12]: a.unsqueeze(0).shape
Out[12]: torch.Size([1, 4, 1, 28, 28])
#unsqueeze 中传入的参数相当于index 表示要增加维度所在的下标
#unsqueeze(0)表示在下标0处增加一个维度
In [13]: a.shape
Out[13]: torch.Size([4, 1, 28, 28])
In [15]: a.unsqueeze(-1).shape
Out[15]: torch.Size([4, 1, 28, 28, 1])
#传入-1 表示倒数第一个下标 传入-2 表示倒数第二个下标
#unsqueeze(-1)也就是在最后一个下标上增加一个维度
In [16]: a.unsqueeze(4).shape
Out[16]: torch.Size([4, 1, 28, 28, 1])
#unsqueeze(4)在第5个下标上增加一个维度 因为下标从0开始
In [17]: a.unsqueeze(-4).shape
Out[17]: torch.Size([4, 1, 1, 28, 28])
#unsqueeze(-4)在倒数第四个下标上增加一个维度 因为加上一个维度后共有五个维度 倒数第四个维度 即为第二个维度
In [18]: a.unsqueeze(-5).shape
Out[18]: torch.Size([1, 4, 1, 28, 28])
#unsqueeze(-5)在倒数第五个下标上增加一个维度 因为加上一个维度后共有五个维度 倒数第五个维度 即为第一个维度
In [19]: a.unsqueeze(5).shape
#因为加上一个维度后一共也就5个维度 所以使用下标5是错误的
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-19-b54eab361a50> in <module>
----> 1 a.unsqueeze(5).shape
RuntimeError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
In [20]: a = torch.tensor([1.2,2.3])
In [21]: a.unsqueeze(-1).shape
Out[21]: torch.Size([2, 1])
In [22]: a.shape
Out[22]: torch.Size([2])
In [23]: a.unsqueeze(-1)
Out[23]:
tensor([[1.2000],
[2.3000]])
In [24]: a
Out[24]: tensor([1.2000, 2.3000])
In [25]: a.unsqueeze(0)
Out[25]: tensor([[1.2000, 2.3000]])
In [26]: a.unsqueeze(0).shape
Out[26]: torch.Size([1, 2])
使用squeeze删除维度
如删除的维度长度不为1 则返回的是原数据
In [27]: b = torch.rand(1,32,1,1)
In [28]: b.shape
Out[28]: torch.Size([1, 32, 1, 1])
In [29]: b.squeeze().shape
Out[29]: torch.Size([32])
In [30]: b.squeeze(0).shape
Out[30]: torch.Size([32, 1, 1])
#删除第一个维度
In [31]: b.squeeze(-1).shape
Out[31]: torch.Size([1, 32, 1])
#删除最后一个维度
In [32]: b.squeeze(1).shape
Out[32]: torch.Size([1, 32, 1, 1])
#删除第二个维度 因为地二个维度不是1 所以保持不变
In [34]: b.squeeze(-4).shape
Out[34]: torch.Size([32, 1, 1])
#删除倒数第四个维度 也就是删除第一个维度
维度扩展----Expand/repeat
区别:expand维度扩展在有需要的时候增加数据,不增加数据时只是改变数据的理解方式,从而节约了内存(推荐操作)
repeat增加了数据 是通过拷贝复制数据
使用expand扩展维度
Expand使用时需要具备两个条件:
1.扩展后的tensor必须和扩展前的tensor维度一致
2.要扩展的维度长度只能为1 否则无法扩展
In [35]: a = torch.rand(4,32,14,14)
In [36]: b = torch.rand(1,32,1,1)
#把b扩展成和a格式一样
In [37]: a.shape
Out[37]: torch.Size([4, 32, 14, 14])
In [38]: b.shape
Out[38]: torch.Size([1, 32, 1, 1])
In [39]: b.expand(4,32,14,14).shape#维度扩展,扩展的维度长度都是1
Out[39]: torch.Size([4, 32, 14, 14])
In [40]: b.expand(-1,32,-1,-1).shape#使用-1 表示原维度长度不变
Out[40]: torch.Size([1, 32, 1, 1])
In [41]: b.expand(-1,32,-1,-4).shape#不会报错 但是这样的做法错误 有bug
Out[41]: torch.Size([1, 32, 1, -4])
条件一可以不具备,可以利用expand扩展出新的维度,例子:
import torch
x = torch.tensor([1, 2, 3])
y = x.expand(2, 3)
print(y)
结果为:
tensor([[1, 2, 3],
[1, 2, 3]])
print(x)
结果为:
tensor([1, 2, 3])
使用repeat扩展维度(不建议使用)
与expand的不同,Expand中传入的参数是扩展后维度的长度,而repeat中传入的参数是要拷贝的次数
In [42]: b.shape
Out[42]: torch.Size([1, 32, 1, 1])
In [43]: b.repeat(4,1,1,1).shape#dim0复制4次 其他维度拷贝一次
Out[43]: torch.Size([4, 32, 1, 1])
In [44]: b.repeat(4,32,1,1).shape#dim0复制4次,dim1复制32次
Out[44]: torch.Size([4, 1024, 1, 1])
In [45]: b.repeat(4,1,32,32).shape#dim0复制4次,dim2复制32次,dim3复制32次
Out[45]: torch.Size([4, 32, 32, 32])
维度转置----transpose/t/permute
使用t进行维度转置
该函数只能用于两个维度矩阵类型的数据进行转置 两维度以上的数据会报错
使用transpose进行转置
代码中使用contiguous函数的原因是因为,在维度重塑后,存储的数据由连续数据变为了不连续数据,因此需要通过该函数转换成连续数据,这与pytorch中数据存储方式有关,可以参考链接
In [15]: a.shape
Out[15]: torch.Size([4, 3, 32, 32])
#[b,c,h,w]
In [16]: a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
#transpose(1,3)交换dim1和dim3 变成[b,w,h,c]==>[4,32,32,3]
#view(4,3*32*32) 变成 [b,w*h*c]
#view(4,3,32,32) 变成 [b,c,w,h]格式与源数据一致 但是顺序已经打乱了 造成数据污染
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-16-262e84a7fdcb> in <module>
----> 1 a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at c:\a\w\1\s\windows\pytorch\aten\src\th\generic/THTensor.cpp:213
In [17]: a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
#transpose(1,3)交换dim1和dim3 变成[b,w,h,c]==>[4,32,32,3]
#view(4,3*32*32) 变成 [b,w*h*c]
#view(4,3,32,32) 变成 [b,c,w,h]格式与源数据一致 但是顺序已经打乱了 造成数据污染
In [18]: a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
#transpose(1,3)交换dim1和dim3 变成[b,w,h,c]==>[4,32,32,3]
#view(4,3*32*32) 变成 [b,w*h*c]
#view(4,3,32,32) 变成 [b,w,h,c]
#transpose(1,3) 变成 [b,c,h,w]格式与源数据一致 顺序也一致
In [19]: a1.shape,a2.shape
Out[19]: (torch.Size([4, 3, 32, 32]), torch.Size([4, 3, 32, 32]))
In [20]: torch.all(torch.eq(a,a1))
Out[20]: tensor(0, dtype=torch.uint8)
#所有数据相同返回1 否则返回0
In [21]: torch.all(torch.eq(a,a2))
Out[21]: tensor(1, dtype=torch.uint8)
使用permute进行转置
#分别使用transpose和permute 把[b,c,h,w]==>[b,h,w,c]
In [22]: a = torch.rand(4,3,28,28)
In [23]: a.transpose(1,3).shape
Out[23]: torch.Size([4, 28, 28, 3])
In [24]: b = torch.rand(4,3,28,32)
In [25]: b.transpose(1,3).shape
Out[25]: torch.Size([4, 32, 28, 3])
#[b,c,h,w]==>[b,w,h,c]
In [26]: b.transpose(1,3).transpose(1,2).shape
Out[26]: torch.Size([4, 28, 32, 3])
#[b,c,h,w]==>[b,w,h,c]==>[b,h,w,c]
In [27]: b.permute(0,2,3,1).shape
Out[27]: torch.Size([4, 28, 32, 3])
#[b,c,h,w]==>[b,h,w,c] 直接根据下标交换维度