在调试用pytorch和numpy写的深度学习代码时,要注意pytorch与numpy默认的浮点类型不一样:
# pytorch与numpy默认的整数类型一样
a = torch.tensor([1, 2, 3, 4])
print(a.dtype) # 查看张量数据类型
b = torch.tensor(np.array([1, 2, 3, 4]))
print(b.dtype)
# pytorch与numpy默认的浮点类型不一样
a = torch.tensor([1., 2., 3., 4.])
print(a.dtype) # 查看张量数据类型
b = torch.tensor(np.array([1., 2., 3., 4.]))
print(b.dtype)
输出结果为:
torch.int64
torch.int64
torch.float32
torch.float64