numpy中默认浮点类型为64位,pytorch中默认浮点类型位32位
测试代码如下
numpy版本:1.19.2
pytorch版本:1.2.0
In [1]: import torch
In [2]: import numpy as np
# 版本信息
In [3]: "pytorch version: {}, numpy version: {}".format(torch.__version__, np.__version__)
Out[3]: 'pytorch version: 1.2.0, numpy version: 1.19.2'
# numpy
In [4]: dat_np = np.array([1,2,3], dtype="float")
In [5]: dat_np.dtype
Out[5]: dtype('float64')
# pytorch
In [6]: dat_torch = torch.tensor([1,2,3])
In [7]: dat_torch = dat_torch.float()
In [8]: dat_torch.dtype
Out[8]: torch.float32