torch.dtype
torch.dtype是表示torch.Tensor的数据类型的对象。PyTorch有八种不同的数据类型:
使用方法:
(1)定义一个tensor
torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False) → Tensor
例子:
torch.tensor([[0.11111, 0.222222, 0.3333333]],
dtype=torch.float64,
device=torch.device('cuda:0'))
>>> x = torch.Tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
>>> print x.type()
torch.FloatTensor #默认是float32
(2)定义一个全0矩阵
注意默认的都是FloatTensor
torch.zeros(*size, out=None, dtype=None, layout=