索引操作
- 索引操作
a = torch.IntTensor(4,3,28,28)
# 对第一张和第二张图片索引
a.index_select(0,torch.tensor([0,2]))
#第一个参数为哪一个维度 第二个参数要求为tensor类型的list
# ...代表任意维度
a[...].shape
# torch.Size([4, 3, 28, 28])
a[0,...].shape
# torch.Size([3, 28, 28]) 第一张图片
# select by mask
x = torch.randn(3,4)
# tensor([[-1.9093, 0.2825, 0.7315, 1.1656],
[-0.4558, 1.3926, -0.9906, 0.4290],
[ 1.8861, 0.8970, 0.4231, 0.8157]])
mask = x.ge(0.5) # 标记元素大于0.5的位置为True,其他位置为0
# tensor([[False, False, True, True],
[False, True, False, False],
[ True, True, False, True]])
torch.masked_select(x,mask)
# tensor([0.7315, 1.1656, 1.3926, 1.8861, 0.8970, 0.8157])
# select by flatten index,把数据打平然后索引
src = torch.tensor([[1,2,3],[4,5,6]])
torch.take(src,torch.tensor([0,2,4]))
# tensor([1, 3, 5])和index_select不同,它第一个参数是维度的选择
维度变换
-
维度变换
view reshape
a = torch.rand(4,1,28,28) a.shape # torch.Size([4, 1, 28, 28]) a.view(4,28*28) a.view(4,28*28).shape # torch.Size([4, 784]) # 可以实现维度的任意变化,注意数据的意义
squeeze(压缩) /unsqueeze(展开)
# unsqueeze,正数在索引左侧插入,负数在在索引右侧插 入一个维度 a.shape # torch.Size([4, 1, 28, 28]) a.unsqueeze(0).shape # torch.Size([1, 4, 1, 28, 28]) a.unsqueeze(-1).shape # torch.Size([4, 1, 28, 28, 1]) a = torch.tensor([1.2,2.3]) a.unsqueeze(-1) # tensor([[1.2000], [2.3000]]) # 在每一个数据后加上一个维度 a.unaqueeze(0) # tensor([[1.2000, 2.3000]]) # 在每一个数据前加上一个维度
example
# 在图片的channel通道加上一个偏置 b = torch.rand(32) f = torch.rand(4,32,14,14) b = b.unsqueeze(1).unsqueeze(1).unsqueeze(0) b.shape # torch.Size([1, 32, 1, 1])
squeeze维度删减
# 不指定维度时,删除所有size为1的维度 b.shape # torch.Size([1, 32, 1, 1]) b.squeeze() # torch.Size([32]) # 指定维度,删除对应索引,若维度!=1,不删除也不报错 b.squeeze(0).shape # torch.Size([32, 1, 1]) b.squeeze(1).shape # torch.Size([1, 32, 1, 1]) 没有删除
维度扩展:expand不主动填充数据,推荐使用,参数就是变化后的维度;repeat主动复制数据,参数为复制的次数
# expand b.shape # torch.Size([1, 32, 1, 1]) b.expand(4,32,14,14).shape # torch.Size([4, 32, 14, 14]) b.expand(-1,32,14,14).shape # torch.Size([1, 32, 14, 14]) 参数-1表示保持原维度不变,懒得写了
# repeat b.shape # torch.Size([1, 32, 1, 1]) b.repeat(4,32,1,1).shape # torch.Size([4, 1024, 1, 1])
数据转置
-
数据转置
在使用view操作,会丢失数据的维度顺序关系,需要人为跟踪数据维度变化
# a.shape # torch.Size([4, 3, 32, 32]) [b,c,h,w] a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32) # [b,c,w,h] a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3) # [b,c,h,w] torch.all(torch.eq(a,a2)) # tensor(True) torch.all(torch.eq(a,a1)) # tensor(False)
permute/要把[b,c,h,w]变为[b,h,w,c]需要两次transpose操作,也可以直接使用permute实现任意次维度交换,其实是内部调用若干次transpose操作
b = torch.rand(4,3,28,32) b.permute(0,2,3,1) # torch.Size([4, 28, 32, 3])