repeat()将原矩阵进行广播,repeat()传入的参数为每个维度复制的次数,复制时从最右边的维度开始
‘’’
扩展步骤如下(倒着执行):
1 最后一个维度1:此时将[1, 2, 3]中的数字直接重复1次,得到[1, 2, 3],保持没变
2 倒数第二个维度2:先将上一步骤的结果增加一个维度,得到[[1, 2, 3]],然后将最外层中括号中的整体重复2次,得到[[1, 2, 3], [1, 2, 3]]
3 倒数第三个维度4:先将上一步骤的结果增加一个维度,得到[[[1, 2, 3], [1, 2, 3]]],然后将最外层中括号中的整体重复4次,
得到[[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]]
4 三个维度扩展结束,得到结果。
‘’’
repeat()参数的数量,即为复制后的维数,所以复制后的维数不能小于复制前的维数。
import torch
x = torch.tensor([[1, 2, 3],[4,5,6]])
# x1 = x.repeat(4) # x若为2维,则此语句报错。
# print(f"{x1}:\n")
x2 = x.repeat(4, 1)
print(f"{x2}:\n")
x3 = x.repeat(4, 2)
print(f"{x3}:\n")
x4 = x.repeat(4, 2, 1)
print(f"{x4}:\n")
https://blog.csdn.net/weixin_41041772/article/details/123296659