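torch.Tensor.expand_as
x.expand_as(y) expands x to the same size as y and is equivalent to x.expand(y.size()). Only dimensions of size 1 can be expanded; every other dimension must already match the target size.
Code example: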
import torch
#1
x = torch.randn(2, 1, 1)  # dims 1 and 2 have size 1, so they can be expanded to 3 and 3
y = torch.randn(2, 3, 3)
x = x.expand_as(y)
print('x :', x.size())
# Output: x : torch.Size([2, 3, 3])
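To make example 1 more concrete, the sketch below (the names a, b and c are illustrative) checks two documented properties of expand_as: it gives the same result as expand(y.size()), and it returns a view that reuses x's memory, with stride 0 along the expanded dimensions, rather than copying data. Calling .contiguous() materializes an independent copy.

```python
import torch

x = torch.randn(2, 1, 1)
y = torch.randn(2, 3, 3)

a = x.expand_as(y)        # expand x to y's shape
b = x.expand(y.size())    # the equivalent call with explicit sizes

print(torch.equal(a, b))             # same values and shape
print(a.stride())                    # the expanded dims (1 and 2) have stride 0
print(a.data_ptr() == x.data_ptr())  # a is a view on x's storage, no copy

c = a.contiguous()        # allocate a real (2, 3, 3) tensor if one is needed
print(c.data_ptr() == x.data_ptr())  # c lives in new memory
```

Because several elements of an expanded view can point to the same memory location, in-place writes into it can behave incorrectly; clone() or contiguous() the result first if you need to modify it.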
#2
x = torch.randn(2, 2, 2)  # dims 1 and 2 have size 2, which cannot be expanded to 3 and 4
y = torch.randn(2, 3, 4)
x = x.expand_as(y)  # raises here, so the print below is never reached
print('x :', x.size())
# Error: RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 2. Target sizes: [2, 3, 4].
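The failure in example 2 follows from the expand rule: comparing dimensions from the right, each dimension of the source must either equal the target dimension or be 1, and the target may add extra leading dimensions. The sketch below encodes that rule in a hypothetical helper can_expand_to (not a PyTorch function) and cross-checks it against what expand_as actually does, using a second hypothetical helper try_expand_as that catches the RuntimeError.

```python
import torch

def can_expand_to(src_shape, target_shape):
    """Hypothetical helper: True if a tensor of src_shape could be expanded
    to target_shape under expand / expand_as rules."""
    if len(src_shape) > len(target_shape):
        return False  # expand cannot remove dimensions
    # Compare right-aligned dims; extra leading target dims are allowed.
    for s, t in zip(reversed(src_shape), reversed(target_shape)):
        if s != t and s != 1:
            return False  # a non-singleton dim must match exactly
    return True

def try_expand_as(x, y):
    """Hypothetical helper: x.expand_as(y), or None if PyTorch rejects it."""
    try:
        return x.expand_as(y)
    except RuntimeError:
        return None

cases = [
    ((2, 1, 1), (2, 3, 3)),  # example 1: size-1 dims expand
    ((2, 2, 2), (2, 3, 4)),  # example 2: size-2 dims cannot become 3 or 4
    ((3, 1), (2, 3, 4)),     # a new leading dim is added
    ((1, 2, 3), (2, 3)),     # target has fewer dims than the source
]
for src, tgt in cases:
    ok_rule = can_expand_to(src, tgt)
    ok_torch = try_expand_as(torch.randn(*src), torch.randn(*tgt)) is not None
    print(src, '->', tgt, ok_rule, ok_torch)
```

If the goal is to make the two tensors the same size even when a dimension is not 1, expand_as is the wrong tool; the data itself has to be reshaped, repeated or padded first.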