目录
repeat函数
import torch
a = torch.tensor([[1, 2, 3],[1, 2, 3]])
print(a)
b = torch.tensor([[2, 2, 2], [3, 3, 3], [3, 3, 3], [3, 3, 3]])
print(a.shape,b.size())
c = a.repeat(2,1)
print(c)
print(c.size())
expand函数:
>>> import torch
>>> a=torch.tensor([[2],[3],[4]])
>>> print(a.size())
torch.Size([3, 1])
>>> a.expand(3,2)
tensor([[2, 2],
[3, 3],
[4, 4]])
>>> a
tensor([[2],
[3],
[4]])
expand
原来维度是2*3
变成 4*2*3是可以的,结果比原来多一个维度
import torch
a = torch.tensor([[1, 2, 3],[4, 5, 6]])
print(a)
print(a.size())
# b = torch.tensor([[2, 2, 2], [3, 3, 3], [3, 3, 3], [3, 3, 3]])
# print(a.shape,