Pytorch 的repeat函数
看代码的时候,对于repeat函数的参数产生了疑问,再查阅资料的情况下算是搞清楚了,这篇博客作为一个学习笔记记录一下。
import torch
data = torch.tensor([1, 2, 3])
data = data.repeat(2, 3)
print(data)
print(data.size())
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3, 1, 2, 3]])
torch.Size([2, 9])
我们新建一个尺寸为1x3的tensor,使用repeat函数如上图所示,函数将原tensor横向复制了两次,纵向复制了三次,新的size为2x9。
import torch
data = torch.tensor([1, 2, 3])
data = data.repeat(2, 1, 3)
print(data)
print(data.size())
tensor([[[1, 2, 3, 1, 2, 3, 1, 2, 3]],
[[1, 2, 3, 1, 2, 3, 1, 2, 3]]])
torch.Size([2, 1, 9])
当repeat函数中带有三个参数时,第一个参数为扩充的通道数,新的size为2x1x9,纵向复制了三次,通道数扩充为2。