维度变换
维度变换相关函数主要有 torch.reshape(或者调用张量的view方法), torch.squeeze, torch.unsqueeze, torch.transpose
torch.reshape 可以改变张量的形状。
torch.squeeze 可以减少维度。
torch.unsqueeze 可以增加维度。
torch.transpose 可以交换维度。
# 张量的view方法有时候会调用失败,可以使用reshape方法。
torch.manual_seed(0)
minval,maxval = 0,255
a = (minval + (maxval-minval)*torch.rand([1,3,3,2])).int()
print(a.shape)
print(a)
torch.Size([1, 3, 3, 2])
tensor([[[[126, 195],
[ 22, 33],
[ 78, 161]],
[[124, 228],
[116, 161],
[ 88, 102]],
[[ 5, 43],
[ 74, 132],
[177, 204]]]], dtype=torch.int32)
# 改成 (3,6)形状的张量
b = a.view([3,6]) #torch.reshape(a,[3,6])
print(b.shape)
print(b)
torch.Size([3, 6])
tensor([[126, 195, 22, 33, 78, 161],
[124, 228, 116, 161, 88, 102],
[ 5, 43, 74, 132, 177, 204]], dtype=torch.int32)
# 改回成 [1,3,3,2] 形状的张量
c = torch.reshape(b,[1,3,3,2]) # b.view([1,3,3,2])
print(c)
tensor([[[[126, 195],
[ 22, 33],
[ 78, 161]],
[[124, 228],
[116, 161],
[ 88, 102]],
[[ 5, 43],
[ 74, 132],
[177, 204]]]], dtype=torch.int32)
如果张量在某个维度上只有一个元素,利用torch.squeeze可以消除这个维度。
torch.unsqueeze的作用和torch.squeeze的作用相反。
a = torch.tensor([[1.0,2.0]])
s = torch.squeeze(a)
print(a)
print(s)
print(a.shape)
print(s.shape)
tensor([[1., 2.]])
tensor([1., 2.])
torch.Size([1, 2])
torch.Size([2])
#在第0维插入长度为1的一个维度
d = torch.unsqueeze(s,axis=0)
print(s)
print(d)
print(s.shape)
print(d.shape)
tensor([1., 2.])
tensor([[1., 2.]])
torch.Size([2])
torch.Size([1, 2])
torch.transpose可以交换张量的维度,torch.transpose常用于图片存储格式的变换上。
如果是二维的矩阵,通常会调用矩阵的转置方法 matrix.t(),等价于 torch.transpose(matrix,0,1)。
minval=0
maxval=255
# Batch,Height,Width,Channel
data = torch.floor(minval + (maxval-minval)*torch.rand([100,256,256,4])).int()
print(data.shape)
# 转换成 Pytorch默认的图片格式 Batch,Channel,Height,Width
# 需要交换两次
data_t = torch.transpose(torch.transpose(data,1,2),1,3)
print(data_t.shape)
torch.Size([100, 256, 256, 4])
torch.Size([100, 4, 256, 256])
matrix = torch.tensor([[1,2,3],[4,5,6]])
print(matrix)
print(matrix.t()) #等价于torch.transpose(matrix,0,1)
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 4],
[2, 5],
[3, 6]])