def linreg(X, w, b):
return torch.mm(X, w) + b
报错主要是这一句,第一次尝试将 X 的数据类型该为float,完全走错了方向(ps:老眼昏花报错没看清)
追溯到前文定义的 w 和 b:
原本设置为float32,别问为啥,我照抄的教材。
w = torch.tensor(np.random.normal(0, 0.01, (num_inputs, 1)), dtype=torch.float32)
b = torch.zeros(1, dtype=torch.float32)
修改为 double 类型即可通过。
w = torch.tensor(np.random.normal(0, 0.01, (num_inputs, 1)), dtype=torch.double)
b = torch.zeros(1, dtype=torch.double)
抄教材还是会出错哒!手动操作一边加深印象!可能还会有意想不到的问题出现。
如果是X 有问题可以直接加一句:
X = torch.tensor(X, dtype=torch.float32)