torch.set_default_dtype(d, /)
- d: 要设置的数据类型
将默认浮点数据类型设置为 d。PyTorch 初始默认浮点数据类型为torch.float32
。
- d:
torch.float16
、torch.float32
、torch.float64
- 当默认浮点数据类型为
torch.float16
时,默认复数数据类型为torch.complex32
- 当默认浮点数据类型为
torch.float32
时,默认复数数据类型为torch.complex64
- 当默认浮点数据类型为
torch.float64
时,默认复数数据类型为torch.complex128
- 当默认浮点数据类型为
torch.bfloat16
时,没有对应的默认复数数据类型。会引发异常
import torch
print(torch.tensor([1.2, 3]).dtype) # pytorch默认使用float32精度
print(torch.tensor([1.2, 3j]).dtype)# 复数默认使用complex64精度
import torch
torch.set_default_dtype(torch.float16)
print(torch.tensor([1.2, 3]).dtype) # 设置默认浮点精度为float16后,tensor的默认精度为float16
print(torch.tensor([1.2, 3j]).dtype)# 此时复数默认使用complex32精度
import torch
torch.set_default_dtype(torch.float64)
print(torch.tensor([1.2, 3]).dtype) # 设置默认浮点精度为float64后,tensor的默认精度为float64
print(torch.tensor([1.2, 3j]).dtype)# 此时复数默认使用complex128精度
import torch
torch.set_default_dtype(torch.bfloat16) # 设置默认浮点精度为bfloat16,会出现错误,因为bfloat16不支持复数
# 错误:RuntimeError: invalid default scalar type for complex