引用官网里的内容:[“ torch.Tensor是torch.FlaotTensor ”]
- Tensor默认生成FloatTensor类型的数据,tensor默认生成LongTensor类型的数据
import torch
data = [1, 2, 3, 4]
tensor_result = torch.tensor(data)
print("tensor:", tensor_result, "|", "tensor生成的数据类型:", tensor_result.type(), "|", "dtpye:", tensor_result.dtype)
Tensor_result = torch.Tensor(data)
print("Tensor:", Tensor_result, "|", "Tensor生成的数据类型:", Tensor_result.type(), "|", "dtpye:", Tensor_result.dtype)
result:
torch.Tensor 是一个类
而 torch.tensor是一个函数
-
Class Tensor下包含很多函数。
-
torch.tensor的函数原型:
torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)
使用
torch.tensor()
总是会从传入参数data
中拷贝数据. 假如你有tensor数据想避免被拷贝,可以使用torch.Tensor.requires_grad_()
ortorch.Tensor.detach()
. 如果你有numpy的数组,并且要避免备拷贝,可以使用torch.as_tensor()
.