主要涉及模块:
PyTorch,numpy
import torch
import numpy as np
# img_np,img_torch_tensor分别表示numpy.ndarray,torch.Tensor对象
# 1.numpy.ndarray转换元素类型为np.float32(可替换为numpy允许的数据类型)
img_np = np.float32(img_np)
# 2.torch.Tensor转换元素类型为torch.float32(可替换为torch允许的数据类型)
img_torch_tensor = img_torch_tensor.type(torch.float32)
img_torch_tensor = img_torch_tensor.to(torch.float32)