- 张量的基本信息
tensor = torch.randn(2,3,4)
print(tensor.type()) # 数据类型 torch.FloatTensor,是一个浮点型的张量
print(tensor.size()) # 张量的shape,是个元组 torch.Size([2, 3, 4])
print(tensor.dim()) # 维度的数量 3
- 张量的命名
程序中,一个好的命名可以便于其他人读懂代码,张量的命名也是如此。这样可以方便地使用维度的名字来做索引或其他操作,提高了可读性、易用性,防止程序出错,便于其他人阅读和修改。
在PyTorch 1.3之前,需要使用注释
Tensor[N, C, H, W]
images = torch.randn(32, 3, 56, 56)
images.sum(dim=1)
images.select(dim=1, index=0)
PyTorch 1.3之后
NCHW = [‘N’, ‘C’, ‘H’, ‘W’]
images = torch.randn(32, 3, 56, 56, names=NCHW)
images.sum(‘C’)
images.select(‘C’, index=0)
也可以这么设置
tensor = torch.rand(3,4,1,2,names=(‘C’, ‘N’, ‘H’, ‘W’))
使用align_to可以对维度方便地排序
tensor = tensor.align_to(‘N’, ‘C’, ‘H’, ‘W’)
4.张量数据类型转换
在Pytorch中,FloatTensor处理速度远远快于DoubleTensor,因此默认采用FloatTensor,也可以通过转换变成其他类型的数据。
设置默认类型
torch.set_default_tensor_type(torch.FloatTensor)
类型转换
tensor = tensor.cuda()
tensor = tensor.cpu()
tensor = tensor.float()
tensor = tensor.long()
4.1 torch.Tensor与np.ndarray转换
除了CharTensor类型外,其他所有CPU上的张量都支持转换为numpy格式,当然也可以再转换回来。
ndarray = tensor.cpu().numpy()
tensor = torch.from_numpy(ndarray).float()
tensor = torch.from_numpy(ndarray.copy()).float() # If ndarray has negative stride.
4.2 torch.tensor与PIL.Image转换
在Pytorch中,张量默认采用[N, C, H, W]的顺序,并且数据范围在[0,1],有时候处理数据时需要进行转置和规范化。
torch.Tensor -> PIL.Image
image = PIL.Image.fromarray(torch.clamp(tensor*255, min=0, max=255).byte().permute(1,2,0).cpu().numpy())
image = torchvision.transforms.functional.to_pil_image(tensor)
PIL.Image -> torch.Tensor
path = r’./figure.jpg’
tensor = torch.from_numpy(np.asarray(PIL.Image.open(path))).permute(2,0,1).float() / 255
tensor = torchvision.transforms.functional.to_tensor(PIL.Image.open(path)) # Equivalently way
5.张量的常用操作
5.1 矩阵乘法
Matrix multiplcation: (mn) * (np) * -> (m*p).
result = torch.mm(tensor1, tensor2)
Batch matrix multiplication: (bmn) * (bnp) -> (bmp)
result = torch.bmm(tensor1, tensor2)
Element-wise multiplication.
result = tensor1 * tensor2
5.2 计算两组数据之间的两两欧式距离
dist = torch.sqrt(torch.sum((X1[:,None,:] - X2) ** 2, dim=2))
5.3 张量形变
将卷积层输入全连接层的情况时,通常需要对张量做形变处理如.view()和.reshape()等,但是相比torch.view,torch.reshape可以自动处理输入张量不连续的情况。
tensor = torch.rand(2,3,4)
shape = (6, 4)
tensor = torch.reshape(tensor, shape)
5.4 打乱顺序
tensor = tensor[torch.randperm(tensor.size(0))] # 打乱第一个维度
5.5 水平翻转
Pytorch不支持tensor[::-1]这样的负步长操作,水平翻转可以通过张量索引实现。
假设张量的维度为[N, D, H, W].
tensor = tensor[:, :, :, torch.arange(tensor.size(3) - 1, -1, -1).long()]
5.6 张量复制
tensor.clone()
tensor.detach()
tensor.detach.clone()
5.7 张量拼接
torch.cat和torch.stack的区别在于torch.cat沿着给定的维度拼接,而torch.stack会新增一维。当参数是3个10x5的张量,torch.cat的结果是30x5的张量,而torch.stack的结果是3x10x5的张量。
tensor = torch.cat(list_of_tensors, dim=0)
tensor = torch.stack(list_of_tensors, dim=0)
5.8 将整数标签转为one-hot编码
Pytorch的标记默认从0开始,转换为one-hot编码在数据处理时也经常用到。
tensor = torch.tensor([0, 2, 1, 3])
N = tensor.size(0)
num_classes = 4
one_hot = torch.zeros(N, num_classes).long()
one_hot.scatter_(dim=1, index=torch.unsqueeze(tensor, dim=1), src=torch.ones(N, num_classes).long())
5.9 得到非零元素
torch.nonzero(tensor) # index of non-zero elements,索引非零元素
torch.nonzero(tensor==0) # index of zero elements,索引零元素
torch.nonzero(tensor).size(0) # number of non-zero elements,非零元素的个数
torch.nonzero(tensor == 0).size(0) # number of zero elements,零元素的个数
5.10 张量扩展
将64512的张量扩展为6451277的张量
tensor = torch.rand(64,512)
torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)