报错代码
原因
打印了下报错行代码:
labels = tensor([[3.]], dtype=torch.float64) 原来是tensor的维度错误,程序期望使用1维数据
解决方法
labels = labels.squeeze(1) # squeeze(i)代表将张量中对应的第i层的维度进行降维
代码成功运行
打印了下报错行代码:
labels = tensor([[3.]], dtype=torch.float64) 原来是tensor的维度错误,程序期望使用1维数据
labels = labels.squeeze(1) # squeeze(i)代表将张量中对应的第i层的维度进行降维
代码成功运行