tensor.expand()
>>> import torch
>>> a = torch.tensor([[1],[2],[3]])
>>> print(a.size())
torch.Size([3, 1])
>>> a.expand(3,4)
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
>>> a
tensor([[1],
[2],
[3]])
>>> a.expand(-1, 4) # -1 means not changing the size of that dimension
tensor([[ 1, 1, 1, 1],
[ 2, 2, 2, 2],
[ 3, 3, 3, 3]])
tensor.expand_as()
>>> b=torch.tensor([[4,7],[5,8],[6,9]])
>>> b
tensor([[ 4, 7],
[ 5, 8],
[ 6, 9]])
>>> print(b.size())
torch.Size([3, 2])
>>> a.expand_as(b)
tensor([[ 1, 1],
[ 2, 2],
[ 3, 3]])
>>> a
tensor([[ 1],
[ 2],
[ 3]])
总结:
tensor.expand()函数是把一个tensor变形为尺寸是括号内所给大小的tensor;
tensor.expand_as()函数是把一个tensor变成和函数括号内所给tensor同样形状的tensor;
expand括号里为size,expand_as括号里为其他tensor。
两个函数均是原来的tensor和变形后tensor不共享内存 的,如需使用变形后的,需重新赋值。