torch.from_numpy VS torch.Tensor
最近在造dataset的时候,突然发现,在输入图像转tensor的时候,我可以用torch.Tensor
直接强制转型将numpy类转成tensor类,也可以用torch.from_numpy
这个方法将numpy类转换成tensor类,那么,torch.Tensor
和torch.from_numpy
这两个到底有什么区别呢?既然torch.Tensor
能搞定,那torch.from_numpy
留着不就是冗余吗?
答案
有区别,使用torch.from_numpy
更加安全,使用tensor.Tensor
在非float类型下会与预期不符。
解释
实际上,两者的区别是大大的。打个不完全正确的比方说,torch.Tensor
就如同c的int
,torch.from_numpy
就如同c++的static_cast
,我们都知道,如果将int64强制转int32,只要是高位转低位,一定会出现高位被抹去的隐患的,不仅仅可能会丢失精度,甚至会正负对调。这里的torch.Tensor
与torch.from_numpy
也会存在同样的问题。
看看torch.Tensor的文档,里面清楚地说明了,
torch.Tensor is an alias for the default tensor type (torch.FloatTensor).
而torch.from_numpy的文档则是说明,
The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.
也即是说,
- 当转换的源是float类型,
torch.Tensor
与torch.from_numpy
会共享一块内存!且转换后的结果的类型是torch.float32 - 当转换的源不是float类型,
torch.Tensor
得到的是torch.float32,而torch.from_numpy
则是与源类型一致!
是不是很神奇,下面是一个简单的例子:
import torch
import numpy as np
s1 = np.arange(10, dtype=np.float32)
s2 = np.arange(10) # 默认的dtype是int64
# 例一
o11 = torch.Tensor(s1)
o12 = torch.from_numpy(s1)
o11.dtype # torch.float32
o12.dtype # torch.float32
# 修改值
o11[0] = 12
o12[0] # tensor(12.)
# 例二
o21 = torch.Tensor(s2)
o22 = torch.from_numpy(s2)
o21.dtype # torch.float32
o22.dtype # torch.int64
# 修改值
o21[0] = 12
o22[0] # tensor(0)