8.第八个方法
torch.set_default_tensor_type(t)
- 这个方法的意思是设置pytorch中默认的浮点类型,一般使用pytorch进行运算时候使用的都是浮点数来进行计算,所以设置默认浮点数有时候也很重要。
- 虽然这个方法和曾经的
torch.set_default_dtype(d)
确实功能很相似,但是实际上今天介绍的这个方法更强大一些(注意两个方法都只可以设置浮点数的默认类型,不可以设置整型的默认类型)。当然这个方法使用后也可以使用torch.get_default_dtype()来获取设置的默认浮点类型。 - Tensor有不同的数据类型,每种类型分别有对应CPU和GPU版本(HalfTensor除外)。默认的Tensor是FloatTensor,可通过torch.set_default_tensor_type修改默认tensor类型(如果默认类型为GPU tensor,则所有操作都将在GPU上进行),HalfTensor是专门为GPU设计的,相同元素个数使用的空间更少,解决显存不足的问题,但是由于精度不足可能会出现溢出的问题。
pytorch中可用的浮点类型
数据类型 | CPU Tensor | GPU Tensor |
---|
32 bit 浮点 | torch.FloatTensor | torch.cuda.FloatTensor |
64 bit 浮点 | torch.DoubleTensor | torch.cuda.DoubleTensor |
16 bit 单精度浮点 | 无 | torch.cuda.HalfTensor |
- 我们可以通过
torch.set_default_tensor_type(t)
来将我们的默认浮点类型设置为cuda类型,这样以后我们就不需要将我们的数据迁移到cuda上了,直接就可以使用GPU加速。但是此方法肯可能会造成一些不好的结果,可能有的时候还得需要自己将tensor迁回cpu,所以是使用device迁到cuda上还是使用着这种默认的方法就因人而异了。
使用方法如下:
import torch
torch.set_default_tensor_type(torch.cuda.FloatTensor)
a = torch.tensor([1., 3])
print(a.dtype,a.device)
- 结果为:

符合预期