目录
在PyTorch中有两个函数可以用来扩展某一维度的张量,即 torch.expand() 和 torch.repeat()
1. torch.expand(*sizes)
【含义】将输入张量在大小为1的维度上进行拓展,并返回扩展更大后的张量
【参数】sizes的shape为torch.Size 或 int,指拓展后的维度, 当值为-1的时候,表示维度不变
import torch
if __name__ == '__main__':
x = torch.rand(1, 3)
y1 = x.expand(4, 3)
print(y1.shape) # torch.Size([4, 3])
y2 = x.expand(6, -1)
print(y2.shape) # torch.Size([6, 3])
2. torch.repeat(*sizes)
【含义】沿着特定维度扩展张量,并返回扩展后的张量
【参数】sizes的shape为torch.Size 或 int,指对当前维度扩展的倍数
import torch
if __name__ == '__main__':
x = torch.rand(2, 3)
y1 = x.repea