PyTorch中的repeat()函数可以对张量进行重复扩充。
repeat()中的参数个数需 >= tensor 维度
举例证明:
>>> abb = torch.rand([2,1,2,3])
>>> print("abb: ",abb)
abb: tensor([[[[0.1103, 0.8384, 0.8689],
[0.2282, 0.9317, 0.4930]]],
[[[0.8018, 0.8036, 0.1346],
[0.9017, 0.4192, 0.6952]]]])
>>> print("abb.shape: ",abb.shape)
abb.shape: torch.Size([2, 1, 2, 3])
>>> bb = abb.repeat(1,2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
看看 错误原因:RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor.
当参数只有两个时:(行的倍数,列的倍数), 此时通道数倍数为1,表示不重复扩充。
当参数有三个时:(通道数的倍数,行的倍数,列的倍数)。
直接将repeat参数对应乘上原tensor的shape就得到最终的shape,如:
shape(2,3) * repeat(1,2) = shape(2,6) //channel = 1
shape(2,3) * repeat(2,2,1) = shape(2,4,3) // shape(2,3) 可看作shape(1,2.3)
>>> bb = abb.repeat(2,2,2,3)
>>> bb
tensor([[[[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930],
[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930]],
[[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930],
[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930]]],
[[[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952],
[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952]],
[[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952],
[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952]]],
[[[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930],
[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930]],
[[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930],
[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
0.8689],
[0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
0.4930]]],
[[[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952],
[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952]],
[[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952],
[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
0.1346],
[0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
0.6952]]]])
>>> print("bb.shape: ",bb.shape)
bb.shape: torch.Size([4, 2, 4, 9])
>>>