有两种情况:
1、(a,1)+(b) = (a,b)
x=torch.randn(2,1)
y=torch.randn(5)
a=x+y
print(x)
print(y)
print(a)
print(a.shape)
结果是
tensor([[1.0769],
[2.5189]])
tensor([ 1.6656, -1.5514, 0.2980, -0.4826, 0.1603])
tensor([[ 2.7425, -0.4745, 1.3748, 0.5942, 1.2371],
[ 4.1846, 0.9675, 2.8169, 2.0363, 2.6792]])
torch.Size([2, 5])
先把(a,1)扩展为(a,b)。
扩展的规则如上式为
tensor([[1.0769],
[2.5189]])
扩展后为:
tensor([[ 1.769, 1.769, 1.768, 1.768, 1.768],
[ 2.5189, 2.5189, 2.5189, 2.5188, 2.5189]])
然后将(a,b)与b中元素对应相加
2、(a,b)+(b) = (a,b)
直接将(a,b)与b中元素对应相加