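Usage
PyTorch has a fairly large set of data types: the dtype of the elements stored in a tensor, the tensor type on the CPU, and the tensor type on the GPU. The table below shows how they correspond.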
Data type | dtype | CPU tensor | GPU tensor |
64-bit floating point | torch.float64 / torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor |
32-bit floating point | torch.float32 / torch.float | torch.FloatTensor | torch.cuda.FloatTensor |
16-bit floating point | torch.float16 / torch.half | torch.HalfTensor | torch.cuda.HalfTensor |
8-bit integer (unsigned) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor |
8-bit integer (signed) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor |
16-bit integer (signed) | torch.int16 / torch.short | torch.ShortTensor | torch.cuda.ShortTensor |
32-bit integer (signed) | torch.int32 / torch.int | torch.IntTensor | torch.cuda.IntTensor |
64-bit integer (signed) | torch.int64 / torch.long | torch.LongTensor | torch.cuda.LongTensor |
Boolean | torch.bool | torch.BoolTensor | torch.cuda.BoolTensor |
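As a quick check of the table, the sketch below inspects a tensor's dtype and its type string. The variable names are illustrative only, and the GPU branch runs only when CUDA is available.

```python
import torch

# The two names in the dtype column are aliases for the same dtype object.
same_double = torch.float64 is torch.double
same_half = torch.float16 is torch.half
same_long = torch.int64 is torch.long

# Calling .type() with no arguments returns the tensor type as a string.
x = torch.zeros((2, 4), dtype=torch.float32)
x_type = x.type()   # CPU tensor type name for float32

y = torch.zeros((2, 4), dtype=torch.int64)
y_type = y.type()   # CPU tensor type name for int64

print(x.dtype, x_type)
print(y.dtype, y_type)

# Assumption: a CUDA device may or may not be present, so guard the GPU case.
if torch.cuda.is_available():
    z = x.to("cuda")
    print(z.dtype, z.type())  # the dtype stays float32; the type name gains the cuda prefix
```

The dtype describes only the element type, while the type string also encodes where the tensor lives, which is why the table needs separate CPU and GPU columns.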
The corresponding type conversion functions in PyTorch include tensor1.type_as(tensor2), tensor.type(torch.IntTensor), tensor.long(), tensor.char(), tensor.int(), tensor.byte(), tensor.double(), and tensor.to(torch.long).
Code example
The following session shows how to create tensors with a given data type and how to convert between data types.
>>> import torch
>>> torch.zeros((2, 4), dtype=torch.int32)
tensor([[0, 0, 0, 0],
        [0, 0, 0, 0]], dtype=torch.int32)
>>> torch.zeros((2, 4), dtype=torch.float32)
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])
>>> torch.zeros((2, 4), dtype=torch.uint8)
tensor([[0, 0, 0, 0],
        [0, 0, 0, 0]], dtype=torch.uint8)
>>> torch.zeros((2, 4), dtype=torch.float64)
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=torch.float64)
>>> A = torch.zeros((2, 4), dtype=torch.int32)
>>> type(A)
<class 'torch.Tensor'>
>>> B = torch.zeros((2, 4), dtype=torch.float32).type_as(A)
>>> print(B)
tensor([[0, 0, 0, 0],
        [0, 0, 0, 0]], dtype=torch.int32)
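The session above exercises only type_as. The other conversion calls listed earlier can be sketched in the same way. This is a minimal illustration, with sample values and variable names chosen only for the example. Each call returns a new tensor and leaves the source tensor unchanged.

```python
import torch

src = torch.tensor([1.7, -2.3, 3.0])        # float32 by default
ref = torch.zeros(3, dtype=torch.int64)     # reference tensor for type_as

as_ref = src.type_as(ref)            # takes the dtype of ref (int64)
as_int = src.type(torch.IntTensor)   # int32, selected by tensor type
as_long = src.long()                 # int64
as_char = src.char()                 # int8
as_double = src.double()             # float64
via_to = src.to(torch.long)          # int64, selected by dtype

# byte() yields uint8; non-negative values are used here because uint8
# cannot represent negative numbers.
as_byte = torch.tensor([0.0, 1.0, 255.0]).byte()

# Float-to-integer conversion drops the fractional part (truncation toward zero).
print(as_long.tolist())   # [1, -2, 3]
print(src.dtype)          # torch.float32, the source is not modified
```

Of these, to() is the most general form, since the same method also accepts a device, whereas long(), int(), and the others are shorthand for one fixed target dtype.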