最近码代码使用pytorch遇到如题所示的问题,查遍Google百度,大多是说运算时维度不符,但是我找遍代码也没发现有这个错误。一段时间后才发现,网络参数保存的是torch.float32类型,而我输入的数据是torch.float64类型,将数据类型更改为torch.float32,问题解决。
这种错误有可能导致如题的错误,还有可能导致所有数据,也就是tensor的data项都是nan,最坑的是pytorch并没有相关的报错,只能用pdb进行debug。
最近码代码使用pytorch遇到如题所示的问题,查遍Google百度,大多是说运算时维度不符,但是我找遍代码也没发现有这个错误。一段时间后才发现,网络参数保存的是torch.float32类型,而我输入的数据是torch.float64类型,将数据类型更改为torch.float32,问题解决。
这种错误有可能导致如题的错误,还有可能导致所有数据,也就是tensor的data项都是nan,最坑的是pytorch并没有相关的报错,只能用pdb进行debug。