张量的操作确实有技巧,pytorch中提供了一些相关的操作,在编程的时候很有用。涉及的方法见下面的目录:
一、拼接张量
1、torch.cat()
torch.cat(tensors, dim=0, out=None) → Tensor
在指定的维度dim上对序列 seq 进行连接操作,注意:张量可以有多个,不一定只是两个。
参数:
tensors (sequence of Tensors) - 相同类型的张量的任何python序列。所提供的非空张量必须具有相同的形状,除了cat尺寸。
dim (int, optional) - 沿着此维度连接张量
out (Tensor, optional) - 输出参数
例子:
>>> x = torch.Tensor([[1, 2, 3],[4, 5, 6]])
>>> x
tensor([[1., 2., 3.],
[4., 5., 6.]])
>>> x.shape # 打印原始形状
torch.Size([2, 3])
>>> y = torch.cat((x, x, x), 0) # 在0维进行拼接
tensor([[1., 2., 3.],
[4., 5., 6.],
[1., 2., 3.],
[4., 5., 6.],
[1., 2., 3.],
[4., 5., 6.]])
>>> y.shape # 打印形状,发现0维形状由 2-->6
torch.Size([6, 3])
>>> z = torch.cat((x, x, x), 1) # 在1维进行拼接
tensor([[1., 2., 3., 1., 2., 3., 1., 2., 3.],
[4., 5., 6., 4., 5., 6., 4., 5., 6.]])
>>> z.shape # 打印形状,发现1维形状由 3-->9
torch.Size([2, 9])
2、torch.stack()
torch.stack(tensors, dim=0, out=None) → Tensor
沿着一个新的维数串联张量序列,所有的张量必须是相同的大小。这个和cat()的不同之处在于新增加了一个维度,新增的维度的位置就是dim。
参数:
tensors (sequence of Tensors) - 连接的张量序列
dim (int, optional) - 维插入。必须在0和连接张量的维数之间(包括)
out (Tensor, optional) - 输出参数
例子:
>>> a = torch.IntTensor([[1,2,3],[11,22,33]])
>>>> a
tensor([[ 1, 2, 3],
[11, 22, 33]], dtype=torch.int32)
>>> a.shape
torch.Size([2, 3])
>>> b= torch.IntTensor([[4,5,6],[44,55,66]])
>>> b
tensor([[ 4, 5, 6],
[44, 55, 66]], dtype=torch.int32)
>>> b.shape
torch.Size([2, 3])
>>> c=torch.stack([a,b],0)
>>> c
tensor([[[ 1, 2, 3],
[11, 22, 33]],
[[ 4, 5, 6],
[44, 55, 66]]], dtype=torch.int32)
>>> c.shape # 在第0维将这两个张量进行拼接,而原始的0维度向后移动
torch.Size([2, 2,<