view/reshape 功能完全一样。对一个张量进行维度变换,但是最终的总数量必须保持不变
import torch
import numpy as np
a = torch.rand(4,2,28,28) # 四维张量
print(a.shape) # 原始的张量
print(a.view(8,28*28).shape) # 维度转换成一维,总大小不变
print(a.view(-1,8)) # -1 表示任意维度(pytorch根据,后面的维度自己推导,如总维度是 28*28*8,此时-1代表的维度就是28*28*8 / 8 =28*28)
torch.Size([4, 2, 28, 28])
torch.Size([8, 784])
torch.Size([784, 8])
suqeeze/unsqueeze 维度的压缩与扩展
unsuqeeze
import torch
import numpy as np
a = torch.rand(4,2,28,28) # 四维张量
print("a 的初始维度是:{}".format(a.shape)) # 原始的张量
print("a 增加一个维度是:{}".format(a.unsqueeze(0).shape)) # 最前面增加一个维度
print("a 增加一个维度是:{}".format(a.unsqueeze(1).shape)) # 在原来第二个维度前增加一个维度
print("a 增加一个维度是:{}".format(a.unsqueeze(-1).shape)) # 最后面增加一个维度
a 的初始维度是:torch.Size([4, 2, 28, 28])
a 增加一个维度是:torch.Size([1, 4, 2, 28, 28])
a 增加一个维度是:torch.Size([4, 1, 2, 28, 28])
a 增加一个维度是:torch.Size([4, 2, 28, 28, 1])
suqeeze
import torch
import numpy as np
a = torch.rand(4,2,28,28,1,1) # 四维张量
print("a 的初始维度是:{}".format(a.shape)) # 原始的张量
print("a 删除维度为1的维度是:{}".format(a.squeeze().shape)) # 删除所有维度为1的维度
print("a 删除指定位置为1的维度是:{}".format(a.squeeze(4).shape)) #
a 的初始维度是:torch.Size([4, 2, 28, 28, 1, 1])
a 删除维度为1的维度是:torch.Size([4, 2, 28, 28])
a 删除指定位置为1的维度是:torch.Size([4, 2, 28, 28, 1])
转置
转置 .t() (二维)
import torch
import numpy as np
a = torch.rand(4,2) # 四维张量
print("a 转置后是:{}".format(a.t().shape))
a 转置后是:torch.Size([2, 4])
transpose 两个维度相互交换
import torch
import numpy as np
a = torch.rand(4,3,32,32) # 四维张量
#transpose 后需要接 contiguous 保证数据在内存的连续性
a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
# torch.all 判断所有元素 torch.eq 判断两个张量是否相等
print("a1 shape : {} a2 shape : {}".format(a1.shape,a2.shape))
# 判断变换后 a 与 a1 是否相同
print("a1 与 a 是否相同:{}".format(torch.all(torch.eq(a,a1))))
# 判断变换后 a 与 a2 是否相同
print("a2 与 a 是否相同:{}".format(torch.all(torch.eq(a,a2))))
a1 shape : torch.Size([4, 3, 32, 32]) a2 shape : torch.Size([4, 3, 32, 32])
a1 与 a 是否相同:False
a2 与 a 是否相同:True
permute 任意维度的交换
import torch
import numpy as np
a = torch.rand(4,3,32,32) # 四维张量
#permute 后需要接 contiguous 保证数据在内存的连续性
a1 = a.permute(0,3,2,1).contiguous() # 按各维度的索引进行排列
print("a1 shape : {} ".format(a1.shape))
a1 shape : torch.Size([4, 32, 32, 3])