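import torch

# torch.randn creates a float32 tensor by default (torch.get_default_dtype())
a = torch.randn(3, 200, 200)
print(a.dtype)  # torch.float32

# Tensor.type(dtype) returns the tensor cast to dtype
# (or the original tensor itself if it already has that dtype)
b = a.type(torch.float16)
print(b.dtype)  # torch.float16

c = a.type(torch.int32)
print(c.dtype)  # torch.int32

# torch.long is an alias for torch.int64
d = a.type(torch.long)
print(d.dtype)  # torch.int64

e = a.type(torch.float32)
print(e.dtype)  # torch.float32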
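Besides `Tensor.type(dtype)`, a tensor's dtype can also be changed with `Tensor.to(dtype)` or with shorthand methods such as `.half()`, `.int()`, `.long()`, `.float()` and `.double()`. A minimal sketch, reusing the same random tensor shape as above; the variable names are only illustrative. Calling `.type()` with no argument returns the type name as a string instead of converting.

```python
import torch

a = torch.randn(3, 200, 200)  # float32 by default
print(a.type())               # torch.FloatTensor (no argument: returns a string)

# Tensor.to(dtype) performs the same dtype conversion as Tensor.type(dtype)
b = a.to(torch.float16)
c = a.to(torch.int32)

# Shorthand methods for common dtypes
half_t = a.half()      # torch.float16
int_t = a.int()        # torch.int32
long_t = a.long()      # torch.int64
float_t = a.float()    # torch.float32
double_t = a.double()  # torch.float64

for t in (b, c, half_t, int_t, long_t, float_t, double_t):
    print(t.dtype)
```

`.to()` is the more general form: the same method also moves tensors between devices, so it is common in code that handles both dtype and device.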
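Changing the dtype can also change the values, not just the type. A small sketch with made-up input values: converting floats to an integer type truncates toward zero rather than rounding, and values beyond the float16 range (largest finite value 65504) become `inf`.

```python
import torch

x = torch.tensor([-1.7, -0.2, 0.4, 2.9])

# Float -> integer conversion drops the fractional part (truncates toward zero)
as_int = x.type(torch.int32)
print(as_int.tolist())  # [-1, 0, 0, 2]

# Round explicitly first if nearest-integer behaviour is wanted
rounded = torch.round(x).type(torch.int32)
print(rounded.tolist())  # [-2, 0, 0, 3]

# float32 values outside the float16 range overflow to inf
big = torch.tensor([70000.0])
print(torch.isinf(big.type(torch.float16)).item())  # True
```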
PyTorch: changing a tensor's data type