用法
torch.is_floating_point(input)
功能:判断输入中的值是否是浮点类型,浮点类型主要有:torch.float64,torch.float32,torch.float16,torch.bfloat16
参数
Input:输入必须是tensor,否则会报错且无法正常输出。
结果输出
如果输入input中的值为float类型,则返回True,否则返回False
例子
>>> a=torch.tensor([1,2],dtype=torch.float64)
>>> torch.is_floating_point(a)
True
>>> a=torch.tensor([1,2],dtype=torch.int8)
>>> torch.is_floating_point(a)
False