pytorch中,repeat()这个函数的功能是很容易理解的,但是函数在不同参数下的执行逻辑通常还是会让初次接触的同学有点疑惑,今天仔细想了一下,感觉可以以这样的方式去理解:
import torch x = torch.tensor([1,2,3]) #将一维度的x扩展到三维 xx = x.repeat(4,2,1) /** 扩展步骤如下(倒着执行): 1 最后一个维度1:此时将[1,2,3]中的数字直接重复1次,得到[1,2,3],保持没变 2 倒数第二个维度2:先将上一步骤的结果增加一个维度,得到[[1,2,3]],然后将最外层中括号中的整体重复2次,得到[[1,2,3],[1,2,3]] 3 倒数第三个维度4:先将上一步骤的结果增加一个维度,得到[[[1,2,3],[1,2,3]]],然后将最外层中括号中的整体重复4次,得到[[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]]] 4 三个维度扩展结束,得到结果。 **/
测试代码如下:
import torch
x = torch.tensor([1,2,3])
x1 = x.repeat(4)
print("x1:\n",x1)
x2 = x.repeat(4,1)
print("x2:\n"x2)
x3 = x.repeat(4,2)
print("x3:\n"x3)
x4 = x.repeat(4,2,1)
print("x4:\n"x4)
测试代码结果:
x1:
tensor([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
x2:
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
x3:
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
x4:
tensor([[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]]])