pytorch基本语法学习(二)
经过pytorch基本语法学习(一)的学习,已经对tensor有的基本的认识,在这节的学习中,将是非常非常重要的一节,因为在深度学习图像领域是经常使用到的知识。整装待发,进入学习状态…
1. tensor维度变化 view
a = torch.FloatTensor(2,3,4)
print('a = ', a) #a = tensor([[[ 1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [-1.2391e-11, 3.0756e-41, 0.0000e+00, 0.0000e+00]],
# [[ 1.4013e-45, 0.0000e+00, 1.6816e-44, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])
print(a.shape) # torch.Size([2, 3, 4])
print(a.view(2,3*4)) #tensor([[ 1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
# 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.2391e-11, 3.0756e-41,
# 0.0000e+00, 0.0000e+00],
# [ 1.4013e-45, 0.0000e+00, 1.6816e-44, 0.0000e+00, 0.0000e+00,
# 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
# 0.0000e+00, 0.0000e+00]])
print(a.view(2,3*4).shape) # torch.Size([2, 12])
print(a.view(2*3,4)) # tensor([[ 1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [-1.2391e-11, 3.0756e-41, 0.0000e+00, 0.0000e+00],
# [ 1.4013e-45, 0.0000e+00, 1.6816e-44, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]])
print(a.view(2*3,4).shape) # torch.Size([6, 4])
print(a.view(2*3*4)) # tensor([ 1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
# 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.2391e-11, 3.0756e-41,
# 0.0000e+00, 0.0000e+00, 1.4013e-45, 0.0000e+00, 1.6816e-44,
# 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
# 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
print(a.view(2*3*4).shape) # torch.Size([24])
b = a.view(2*3*4)
print(b.view(2,3,4)) # tensor([[[ 1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [-1.2391e-11, 3.0756e-41, 0.0000e+00, 0.0000e+00]],
# [[ 1.4013e-45, 0.0000e+00, 1.6816e-44, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])
print(b.view(2,3,4).shape) # torch.Size([2, 3, 4])
print(b.view(3,4,2)) # tensor([[[ 1.4013e-45, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00]],
# [[-1.2391e-11, 3.0756e-41],
# [ 0.0000e+00, 0.0000e+00],
# [ 1.4013e-45, 0.0000e+00],
# [ 1.6816e-44, 0.0000e+00]],
# [[ 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00],
# [ 0.0000e+00, 0.0000e+00]]])
print(b.view(3,4,2).shape) # torch.Size([3, 4, 2])
2. tensor维度变化 squeeze&unsqueeze:压缩或扩展维度,注意只能将dim=1的位置压缩,且扩展也是在指定的位置上扩展为dim=1
a = torch.FloatTensor(2, 3, 4)
print(a.shape) # torch.Size([2, 3, 4])
print(a.unsqueeze(0).shape) # torch.Size([1, 2, 3, 4])
print(a.unsqueeze(1).shape) # torch.Size([2, 1, 3, 4])
print(a.unsqueeze(2).shape) # torch.Size([2, 3, 1, 4])
print(a.unsqueeze(3).shape) # torch.Size([2, 3, 4, 1])
b = torch.FloatTensor(1, 2, 1, 1)
print(b.shape) # torch.Size([1, 2, 1, 1])
print(b.squeeze(0).shape) # torch.Size([2, 1, 1])
print(b.squeeze(2).shape) # torch.Size([1, 2, 1])
print(b.squeeze(3).shape) # torch.Size([1, 2, 1])
3. tensor维度变化 expand & repeat:如果没有特殊需求本人建议使用expend,因为repeat会占用额外内存,注意观察expand & repeat输出shape的区别
b = torch.FloatTensor(1, 2, 1, 1)
print(b.shape) # torch.Size([1, 2, 1, 1])
print(b.expand(4,2,5,6).shape) # torch.Size([4, 2, 5, 6])
print(b.repeat(4,2,1,1).shape) # torch.Size([4, 4, 1, 1])
print(b.repeat(4,1,2,2).shape) # torch.Size([4, 2, 2, 2])
4. tensor维度变化transpose & permute
#注意观察b,c的转换,虽然a,b,c的shape是相同的,但是b与a并不相同,这也是容易出现错误的地方。
a = torch.FloatTensor(4, 3, 32, 32)
print(a.transpose(1,3).shape) # torch.Size([4, 32, 32, 3]) 维度1和3换位置
b = a.transpose(1,3).contiguous().view(4,3*32*32).view(4, 3, 32, 32)
print(b.shape) # torch.Size([4, 3, 32, 32])
c = a.transpose(1,3).contiguous().view(4,3*32*32).view(4, 32, 32, 3).transpose(1, 3)
print(c.shape) # torch.Size([4, 3, 32, 32])
print(torch.all(torch.eq(a,b))) # tensor(False) 判断a,b是否相等
print(torch.all(torch.eq(a,c))) # tensor(True) 判断a,c是否相等
c = torch.FloatTensor(4, 3, 28, 32)
print(c.transpose(1,3).shape) # torch.Size([4, 32, 28, 3])
print(c.transpose(1,3).transpose(1,2).shape) # torch.Size([4, 28, 32, 3])
print(c.permute(0,2,3,1).shape) # torch.Size([4, 28, 32, 3])
print(torch.all(torch.eq(c.permute(0,2,3,1),c.transpose(1,3).transpose(1,2)))) # tensor(True)
在下一个学习中主要练习一下tensor的拼接与拆分,数学运算和统计。