torch.Tensor
有两个实例方法可以用来扩展某维的数据的尺寸,分别是 repeat()
和 expand()
。
expand()
返回当前张量在某维扩展更大后的张量。按照指定size扩充。
扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。
代码示例:
In [45]: x = torch.randn(1,3)
In [46]: x
Out[46]: tensor([[-1.1352, 0.3773, -0.2824]])
In [47]: x.expand(2, 3)
Out[47]:
tensor([[-1.1352, 0.3773, -0.2824],
[-1.1352, 0.3773, -0.2824]])
In [48]: x.expand(2, -1)
Out[48]:
tensor([[-1.1352, 0.3773, -0.2824],
[-1.1352, 0.3773, -0.2824]])
repeat()
沿着特定的维度重复这个张量,按照倍数扩充;和expand()
不同的是,这个函数拷贝张量的数据。
In [53]: x
Out[53]: tensor([[-1.1352, 0.3773, -0.2824]])
In [54]: x.shape
Out[54]: torch.Size([1, 3])
In [55]: x.repeat(2,3)
Out[55]:
tensor([[-1.1352, 0.3773, -0.2824, -1.1352, 0.3773, -0.2824, -1.1352, 0.3773,
-0.2824],
[-1.1352, 0.3773, -0.2824, -1.1352, 0.3773, -0.2824, -1.1352, 0.3773,
-0.2824]])
In [56]: x.repeat(2,3).shape
Out[56]: torch.Size([2, 9])