import numpy as np

X = np.random.randn(100, 10)
W = np.random.randn(10, 64)
b = np.ones(64)
z = X @ W + b # Works
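Running the same code with torch raises an error, because broadcasting is not implemented yet in the PyTorch version used here (it was added in PyTorch 0.2):
import torch

X = torch.randn(100, 10)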
W = torch.randn(10, 64)
b = torch.ones(64)
z = X @ W + b # Error, cannot add tensor of size [100, 64] and [64]
The workaround is to tile the bias to the full output shape with the repeat method:
X = torch.randn(100, 10)
W = torch.randn(10, 64)
b = torch.ones(64)
z = X @ W + b.repeat(X.size(0), 1)
Example of how repeat tiles a tensor:
b = torch.ones(1)
b.repeat(2,5)
1 1 1 1 1
1 1 1 1 1
[torch.FloatTensor of size 2x5]
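
The repeat workaround can be checked with a short sketch. It assumes a recent PyTorch release (0.2 or later), where broadcasting is available, so the three ways of adding the bias can be compared directly. The names b_repeated, b_expanded and z_broadcast are illustrative and do not come from the original notes. Using expand as an alternative is also an addition here: it produces the same shape as repeat but as a view, without copying the data.

```python
import torch

torch.manual_seed(0)  # fixed seed so the comparison is reproducible

X = torch.randn(100, 10)
W = torch.randn(10, 64)
b = torch.ones(64)

# repeat(100, 1) tiles the 64-element bias 100 times along a new leading
# dimension, giving a (100, 64) tensor that holds copies of the data.
b_repeated = b.repeat(X.size(0), 1)

# Alternative (not in the original notes): expand yields the same shape
# as a view, so no extra memory is allocated for the copies.
b_expanded = b.unsqueeze(0).expand(X.size(0), b.size(0))

z_repeat = X @ W + b_repeated
z_expand = X @ W + b_expanded

# Assumption: PyTorch >= 0.2, where [100, 64] + [64] broadcasts automatically.
z_broadcast = X @ W + b

print(tuple(b_repeated.shape))
print(torch.allclose(z_repeat, z_expand))
print(torch.allclose(z_repeat, z_broadcast))
```

On versions with broadcasting, the plain `X @ W + b` form is usually the simplest choice. repeat and expand remain useful when an explicitly shaped tensor is needed.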