torch.unit8 : G = torch.Tensor(H).byte().to(device)
torch.int16: G = torch.Tensor(H).short().to(device)
设置: 通过一些内置函数,可以实现对tensor的精度, 类型,print打印参数等进行设置
torch.set_default_dtype(d) #对torch.tensor() 设置默认的浮点类型
torch.set_default_tensor_type() # 同上,对torch.tensor()设置默认的tensor类型
>>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
torch.float32
>>> torch.set_default_dtype(torch.float64)
>>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
torch.float64
>>> torch.set_default_tensor_type(torch.DoubleTensor)
>>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
torch.float64
torch.get_default_dtype() #获得当前默认的浮点类型torch.dtype
torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None)#)
## 设置printing的打印参数
对nn. 设置参数类型
lin = nn.Linear(10, 10).double()