关于不同Tensor类型的差异
在pytorch中,我们处理的变量都属于Tensor,而在官方文档中,tensor有多个类型可选。但实际上,我们最常用的还是FloatTenso
r和LongTensor
。
下面我们就来简要介绍这两类tensor的使用。
- 在定义tensor时,如果你不指定具体类型,那tensor变量为默认类型,
torch.FloatTensor
。 - 对于tensor之间的操作,规则是严格的。要求操作的两个张量必须具有相同的数据类型,否则程序会报错。
- 再举一个例子,像
CrossEntropyLoss
这样的损失函数要求目标应该是LongTensor
。所以在进行操作之前,要保证输入Tensor 类型与函数定义匹配。
Tensor 类型转换
如下官方文档所说,我们只需to
我们想要的类型即可。
Performs Tensor dtype and/or device conversion. A torch.dtype and torch.device are inferred from the arguments of self.to(*args, **kwargs).
示例代码如下:
>>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu
>>> tensor.to(torch.float64)
tensor([[-0.5044, 0.0005],
[ 0.3310, -0.0584]], dtype=torch.float64)
FloatTensor vs DoubleTensor
当前英伟达芯片大多只支持单精度,而且单精度与双精度这种精度差异对模型的准确率影响不大。