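Runtime error:
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'mat2' in call to _th_addmm_out
Cause: the tensor dtypes do not match. With integer arguments, t.arange(0,20) returns an int64 (Long) tensor, while w is a float32 (Float) tensor, and x.mm(w) requires both operands to have the same dtype.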
Original code:
display.clear_output(wait=True)
x=t.arange(0,20).view(-1,1)  # integer arguments, so x is int64 (Long)
y_pred=x.mm(w)+b.expand_as(x)
Fix (cast x to float before the matrix multiply):
display.clear_output(wait=True)
x=t.arange(0,20).view(-1,1)
x=x.float()  # convert to float32 so it matches the dtype of w
y_pred=x.mm(w)+b.expand_as(x)
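A minimal, self-contained sketch that reproduces the mismatch and both ways of avoiding it. The original snippet does not show how w and b are defined, so here they are hypothetical stand-ins, assumed to be float32 tensors of shape (1, 1) as in a typical linear-regression demo. The IPython display call is left out because it only clears notebook output. The exact wording of the RuntimeError differs between PyTorch versions.

```python
import torch as t

# Hypothetical parameters (assumption: float32, shape (1, 1)); the original does not show them
w = t.rand(1, 1)
b = t.zeros(1, 1)

# Integer arguments make arange return an int64 (Long) tensor
x = t.arange(0, 20).view(-1, 1)

# Multiplying int64 by float32 with mm raises a RuntimeError (no automatic type promotion)
mismatch_raised = False
try:
    x.mm(w)
except RuntimeError as e:
    mismatch_raised = True
    print("dtype mismatch:", e)

# Fix 1 (as in the note): cast x to float32 after creating it
x_float = x.float()
y_pred = x_float.mm(w) + b.expand_as(x_float)

# Fix 2 (alternative): ask arange for float32 directly
x2 = t.arange(0, 20, dtype=t.float32).view(-1, 1)
y_pred2 = x2.mm(w) + b.expand_as(x2)
```

Both fixes yield the same float32 input; passing dtype to arange just avoids creating the int64 tensor in the first place.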