张量数据类型
Data type | dtype | CPU tensor | GPU tensor |
---|---|---|---|
32-bit floating point | torch.float32 or torch.float | torch.FloatTensor | torch.cuda.FloatTensor |
64-bit floating point | torch.float64 or torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor |
16-bit floating point | torch.float16 or torch.half | torch.HalfTensor | torch.cuda.HalfTensor |
8-bit integer (unsigned) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor |
8-bit integer (signed) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor |
16-bit integer (signed) | torch.int16 or torch.short | torch.ShortTensor | torch.cuda.ShortTensor |
32-bit integer (signed) | torch.int32 or torch.int | torch.IntTensor | torch.cuda.IntTensor |
64-bit integer (signed) | torch.int64 or torch.long | torch.LongTensor | torch.cuda.LongTensor |
类型使用
示例1:类型比较
import torch

# Random normal initialization: a 2x3 float32 tensor on the CPU.
a = torch.randn(2, 3)

# Tensor.type() returns the legacy type name as a string.
type_name = a.type()
# The Python class of a tensor object, as reported by the builtin type().
py_class = type(a)
# isinstance() against a legacy tensor type checks the tensor's dtype/device.
is_float_tensor = isinstance(a, torch.FloatTensor)

print("a.type(): ", type_name)
print("type(a): ", py_class)
print("type of 'a' is torch.FloatTensor: ", is_float_tensor)
示例2:同一数据被部署在CPU和GPU上类型不同
import torch

# Random normal initialization: the tensor is created on the CPU.
a = torch.randn(2, 3)
# A CPU tensor is not an instance of the CUDA legacy type, so this prints False.
print(isinstance(a, torch.cuda.FloatTensor))

# Moving the tensor to the GPU raises an error when CUDA is unavailable,
# so guard the GPU half of the example to keep it runnable on CPU-only machines.
if torch.cuda.is_available():
    a = a.cuda()
    # The same data, now on the GPU, is an instance of the CUDA type.
    print(isinstance(a, torch.cuda.FloatTensor))
else:
    print("CUDA is not available; skipping the GPU part of this example.")
示例3:标量
import torch

# Passing a plain Python number to torch.tensor yields a 0-dim (scalar) tensor.
a = torch.tensor(1.)
b = torch.tensor(1.3)

# Print each scalar with its label.
for label, scalar in (("a: ", a), ("b: ", b)):
    print(label, scalar)
示例4:标量的shape
import torch

# A scalar tensor: its shape is an empty torch.Size.
a = torch.tensor(2.2)
shape = a.shape
print("a.shape: ", shape)

# The number of dimensions can be read via len(shape) or dim(); both are 0 here.
print("len(a.shape): ", len(shape))
print("a.dim(): ", a.dim())
# size() with no argument returns the same torch.Size as .shape.
print("a.size(): ", a.size())
示例5:一维张量
import numpy as np
import torch

# 1-D tensors from Python lists: the list length becomes the only dimension.
print(torch.tensor([1.1]))
print(torch.tensor([1.1, 2.2]))

# The legacy constructor torch.FloatTensor(n) with an int argument allocates a
# length-n float32 tensor WITHOUT initializing it, so the printed values are
# whatever happened to be in memory and may differ between runs.
for length in (1, 2):
    print(torch.FloatTensor(length))

# From NumPy: torch.from_numpy shares memory with the array and keeps its dtype
# (np.ones defaults to float64, so the resulting tensor is float64).
ones = np.ones(2)
print("data: ", ones)
data = torch.from_numpy(ones)
print("tensor from numpy:", data)

# Three ways to inspect the geometry of the 1-D tensor.
print("data.shape: ", data.shape)
print("data.dim(): ", data.dim())
print("data.size(): ", data.size())
示例6:二维张量
import torch

# Standard-normal initialization of a 2x3 matrix.
a = torch.randn(2, 3)
print("a:", a)

# .shape and .size() report the same torch.Size.
print("a.shape:", a.shape)
print("a.size():", a.size())

# Length of individual dimensions: size(i) and shape[i] are interchangeable.
rows, cols = a.size(0), a.size(1)
print("a.size(0):", rows)
print("a.size(1):", cols)
print("a.shape[1]:", a.shape[1])

# Number of dimensions (rank) of the matrix.
print("a.dim():", a.dim())
示例7:三维张量
import torch

# Uniform initialization over [0, 1) of a 2x2x3 tensor.
a = torch.rand(2, 2, 3)
print("a:", a)

# The shape as a plain Python list, and the equivalent torch.Size from size().
print("list(a.shape):", list(a.shape))
dims = a.size()
print("a.size():", dims)

# Per-dimension lengths via size(i) and shape[i].
print("a.size(0):", a.size(0))
print("a.size(1):", a.size(1))
print("a.shape[1]:", a.shape[1])

# Rank of the tensor.
print("a.dim():", a.dim())
示例8:获得总元素个数
import torch

# Uniform initialization over [0, 1) of a 2x2x3 tensor.
a = torch.rand(2, 2, 3)

# numel() is the total element count: the product of all sizes, 2 * 2 * 3.
element_count = a.numel()
print("a.numel():", element_count)