Error:
Code:
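import torch
# model and loss_func are defined elsewhere (not shown)
xb = x_train[:64]                  # first mini-batch of 64 images
xb = xb.reshape(-1, 28*28)         # flatten each 28x28 image to 784 values
# print(xb.shape)
yb = y_train[:64]
weights = torch.randn([784, 10], dtype=torch.float, requires_grad=True)
bias = torch.zeros(10, requires_grad=True)
print(loss_func(model(xb), yb))    # fails here: xb is not float32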
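Cause: the input data type should be float. `xb` keeps whatever dtype `x_train` has (for example an integer or double type), while `weights` and `bias` are `torch.float` (float32). The matrix multiplication inside the model does not mix dtypes, so both operands must be float32.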
Fix:
xb = xb.reshape(-1, 784).float()   # flatten and cast to float32
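A minimal self-contained sketch that reproduces the mismatch and the fix. It assumes `x_train` holds raw uint8 pixels of shape (N, 28, 28); the data here is random, and `model` (one linear layer plus log-softmax) and `loss_func` (`F.nll_loss`) are hypothetical stand-ins for the definitions not shown above.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in data: uint8 pixels in [0, 255], integer class labels.
x_train = torch.randint(0, 256, (128, 28, 28), dtype=torch.uint8)
y_train = torch.randint(0, 10, (128,))

weights = torch.randn([784, 10], dtype=torch.float, requires_grad=True)
bias = torch.zeros(10, requires_grad=True)

def model(xb):
    # Hypothetical model: linear layer followed by log-softmax over 10 classes.
    return F.log_softmax(xb @ weights + bias, dim=1)

# Hypothetical loss: negative log-likelihood on log-probabilities.
loss_func = F.nll_loss

xb = x_train[:64].reshape(-1, 28 * 28)
yb = y_train[:64]

# uint8 batch @ float32 weights: matmul does not promote dtypes, so this raises.
try:
    model(xb)
    failed = False
except RuntimeError:
    failed = True

# Fix: cast the batch to float32 before it reaches the model.
xb = xb.float()
loss = loss_func(model(xb), yb)
print(failed, xb.dtype, loss.item())
```

Alternatively, the whole dataset can be cast once up front (`x_train = x_train.float()`), so every batch sliced from it is already float32.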