RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'
if i%1000 == 0:
    display.clear_output(wait=True)
    x = t.arange(0,20).view(-1, 1)
    print(x.dtype)
    # print(x)
    y = x.mm(w) + b.expand_as(x)
    plt.plot(x.numpy(), y.numpy())
    x2, y2 = get_fake_data(batch_size=20)
    plt.scatter(x2.numpy(), y2.numpy())
    plt.xlim(0, 20)
    plt.ylim(0, 41)
    plt.show()
    plt.pause(0.5)
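Line 6 of the snippet, y = x.mm(w) + b.expand_as(x), raises the error because the dtypes do not match. x is an integer tensor (t.arange(0,20) with integer arguments yields torch.int64, i.e. Long), while w is torch.float, and mm requires both operands to have the same dtype.
Fix: convert x to float before the matrix multiply. Adding x = x.float() right after x is created is enough.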
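The fix can be checked in isolation with a minimal sketch like the one below. The values of w and b here are hypothetical stand-ins (the original script's trained parameters are not shown), and the variant that passes dtype to t.arange is an alternative not mentioned in the original note.

```python
import torch as t

# Hypothetical stand-ins for the parameters used in the original script.
w = t.rand(1, 1)   # float32 weight, shape (1, 1)
b = t.zeros(1, 1)  # float32 bias, shape (1, 1)

# Integer arguments make arange return an int64 (Long) tensor.
x_int = t.arange(0, 20).view(-1, 1)

# Reproduce the failure: mm does not accept mixed int64/float32 operands.
try:
    x_int.mm(w)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Fix from the note: cast x to float before multiplying.
x = x_int.float()

# Alternative (an assumption, not in the original): create x as float directly.
x_alt = t.arange(0, 20, dtype=t.float32).view(-1, 1)

y = x.mm(w) + b.expand_as(x)
print(x_int.dtype, x.dtype, y.shape)
```

Casting with x.float() keeps the rest of the plotting code unchanged, since x.numpy() and y.numpy() then both return float arrays.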