![31ee3b5abdc62414011efbd9e81b23d1.png](https://img-blog.csdnimg.cn/img_convert/31ee3b5abdc62414011efbd9e81b23d1.png)
1. 从数组、列表对象创建
Numpy Array 数组和 Python List 列表是 Python 程序中间非常重要的数据载体容器,很多数据都是通过 Python 语言将数据加载至 Array 数组或者 List 列表容器,再转换到 Tensor 类型。(为了方便描述,后面将 Numpy Array 数组称为数组,将 Python List 列表称为列表。)
PyTorch 从数组或者列表对象中创建 Tensor 有四种方式:
- torch.Tensor
- torch.tensor
- torch.as_tensor
- torch.from_numpy
>>> import torch
>>> import numpy as np
>>> array = np.array([1, 2, 3])
>>> list = [4, 5, 6]
# 方式一:使用torch.Tensor类
>>> tensor_array_a = torch.Tensor(array)
>>> tensor_list_a = torch.Tensor(list)
>>> print(isinstance(tensor_array_a, torch.Tensor)
, tensor_array_a.type())
True torch.FloatTensor
>>> print(isinstance(tensor_list_a, torch.Tensor)
, tensor_list_a.type())
True torch.FloatTensor
# 方式二:使用torch.tensor函数
>>> tensor_array_b = torch.tensor(array)
>>> tensor_list_b = torch.tensor(list)
>>> print(isinstance(tensor_array_b, torch.Tensor)
, tensor_array_b.type())
True torch.LongTensor
>>> print(isinstance(tensor_list_b, torch.Tensor)
, tensor_list_b.type())
True torch.LongTensor
# 方式三:使用torch.as_tensor函数
>>> tensor_array_c = torch.as_tensor(array)
>>> tensor_list_c = torch.as_tensor(list)
>>> print(isinstance(tensor_array_c, torch.Tensor)
, tensor_array_c.type())
True torch.LongTensor
>>> print(isinstance(tensor_list_c, torch.Tensor)
, tensor_list_c.type())
True torch.LongTensor
# 方式四:使用torch.from_numpy函数
>>> tensor_array_d = torch.from_numpy(array)
# tensor_list_d = torch.from_numpy(list) error code
>>> print(isinstance(tensor_array_d, torch.Tensor)
, tensor_array_d.type())
True torch.LongTensor
# print(isinstance(tensor_list_d, torch.Tensor)
# , tensor_list_d.type())
通过上面代码的执行结果可以简单归纳出四种创建 Tensor 方式的差异:
- 只有 torch.Tensor 是类,其余的三种方式都是函数;
- torch.Tensor、torch.tensor 和 torch.as_tensor 三种方式可以将数组和列表转换为 Tensor,但是 torch.from_numpy 只能将数组转换为 Tensor(为 torch.from_numpy 函数传入列表,程序会报错);
- 从程序的输出结果可以看出,四种方式最终都将数组或列表转换为 Tensor(使用 isinstance 返回的结果都为 True),**但是转换后的 Tensor 数据类型却有所不同,在上一小节区分 torch.Tensor 和 torch.tensor 的时候提到过,当接收数据内容时,torch.Tensor 创建的 Tensor 会使用默认的全局数据类型,而 torch.tensor 创建的 Tensor 会使用根据传入数据推断出的数据类型。**可以通过
torch.get_default_dtype()
来获取当前的全局数据类型,也可以通过torch.set_default_dtype(torch.XXXTensor)
来设置当前环境默认的全局数据类型;
>>> import torch
>>> import numpy as np
>>> array = np.array([1, 2, 3])
>>> print(array)
int64
# 获取当前全局环境的数据类型
>>> print(torch.get_default_dtype())
torch.float32
# 方式一:使用torch.Tensor类
>>> tensor_array_a = torch.Tensor(array)
>>> print(tensor_array_a.type())
torch.FloatTensor
# 方式二:使用torch.tensor函数
>>> tensor_array_b = torch.tensor(array)
>>> print(tensor_array_b.type())
torch.LongTensor
# 设置当前全局环境的数据类型为torch.DoubleTensor
>>> torch.set_default_tensor_type(torch.DoubleTensor)
>>> tensor_array_a = torch.Tensor(array)
>>> print(tensor_array_a.type())
torch.DoubleTensor
>>> tensor_array_b = torch.tensor(array)
>>> print(tensor_array_b.type())
torch.LongTensor
**PyTorch 默认的全局数据类型为 torch.float32,因此使用 torch.Tensor 类创建 Tensor 的数据类型和默认的全局数据类型一致,为 torch.FloatTensor,而使用 torch.tensor 函数创建