对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看。
维度查看:torch.Tensor.size()
查看当前 tensor 的维度
举个例子:
>>> import torch
>>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]])
>>> a.size()
torch.Size([1, 3, 2])
张量变形:torch.Tensor.view(*args) → Tensor
返回一个有相同数据但大小不同的 tensor。 返回的 tensor 必须有与原 tensor 相同的数据和相同数目的元素,但可以有不同的大小。一个 tensor 必须是连续的 contiguous() 才能被查看。
举个例子:
>>> x = torch.randn(2, 9)
>>> x.size()
torch.Size([2, 9])
>>> x
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038, 0.5166, 0.9774,
0.3455],
[-0.2306, 0.4217, 1.2874, -0.3618, 1.7872, -0.9012, 0.8073, -1.1238,
-0.3405]])
>>> y = x.view(3, 6)
>>> y.size()
torch.Size([3, 6])
>>> y
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038],
[ 0.5166, 0.9774, 0.3455, -0.2306, 0.4217, 1.2874],
[-0.3618, 1.7872, -0.9012, 0.8073, -1.1238, -0.3405]])
>>> z = x.view(2, 3, 3)
>>> z.size()
torch.Size([2, 3, 3])
>>> z
tensor([[[-1.6833, -0.4100, -1.5534],
[-0.6229, -1.0310, -0.8038],
[ 0.5166, 0.9774, 0.3455]],
[[-0.2306, 0.4217, 1.2874],
[-0.3618, 1.7872, -0.9012],
[ 0.807