Torch 定义了九种CPU tensor类型和九种GPU tensor类型
Data type | dtype | CPU tensor | GPU tensor |
---|
32位浮点型 | torch.float32 或 torch.float | torch.FloatTensor | torch.cuda.FloatTensor |
64位浮点型 | torch.float64 或 torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor |
16位浮点型 | torch.float16 或 torch.half | torch.HalfTensor | torch.cuda.HalfTensor |
8位整型 (无符号) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor |
8位整型 (有符号) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor |
16位整型 (有符号) | torch.int16 或 torch.short | torch.ShortTensor | torch.cuda.ShortTensor |
32位整型 (有符号) | torch.int32 或 torch.int | torch.IntTensor | torch.cuda.IntTensor |
64位整型 (有符号) | torch.int64 或 torch.long | torch.LongTensor | torch.cuda.LongTensor |
布尔型 | torch.bool | torch.BoolTensor | torch.cuda.BoolTensor |
torch.Tensor
是默认的tensor类型 (torch.FloatTensor
)的别名。