PyTorch中的repeat()函数可以对张量进行复制。
import torch
a = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
print(a.shape)
print(a)
# 将第0维复制3次
b = a.repeat(3, 1, 1)
print(b.shape)
print(b)
PyTorch中的repeat()函数可以对张量进行复制。
import torch
a = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
print(a.shape)
print(a)
# 将第0维复制3次
b = a.repeat(3, 1, 1)
print(b.shape)
print(b)