torchvision.transforms中定义了一系列数据转换形式,有PILImage,numpy,Tensor间相互转换,还能对数据进行处理。
在torchvision.datasets下载数据的时候,作为一个参数传入,对下载的数据进行处理
(关于数据集下载,具体可参见https://blog.csdn.net/qq_37385726/article/details/81771943)
目录
1.torchvision.transforms.ToTensor()
把一个取值范围是[0,255]的PIL.Image 转换成 Tensor
shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的Tensor
2.torchvision.transforms.ToPILImage()
3.综合处理,ndarray->Tensor->PILImage
深入理解Ptorchvision.transforms.ToTensor与ToPILImage
1.torchvision.transforms.ToTensor()
将PILImage或者numpy的ndarray转化成Tensor
- 对于PILImage转化的Tensor,其数据类型是torch.FloatTensor
- 对于ndarray的数据类型没有限制,但转化成的Tensor的数据类型是由ndarray的数据类型决定的。
-
把一个取值范围是[0,255]的PIL.Image 转换成 Tensor
img1 = Image.open('./Image/use_Crop.jpg')
t_out = transforms.ToTensor()(img1)
-
shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的Tensor
n_out = np.random.rand(100,100,3)
print(n_out.dtype)
t_out = transforms.ToTensor()(n_out)
print(t_out.type())
输出
float64
torch.DoubleTensor
2.torchvision.transforms.ToPILImage()
将Numpy的ndarray或者Tensor转化成PILImage类型【在数据类型上,两者都有明确的要求】
- ndarray的数据类型要求dtype=uint8, range[0, 255] and shape H x W x C
- Tensor 的shape为 C x H x W 要求是FloadTensor的,不允许DoubleTensor或者其他类型
# to a PIL.Image of range [0, 255]
-
将ndarray转化成PILImage
#初始化随机数种子
np.random.seed(0)
data = np.random.randint(0, 255, 300)
print(data.dtype)
n_out = data.reshape(10,10,3)
#强制类型转换
n_out = n_out.astype(np.uint8)
print(n_out.dtype)
img2 = transforms.ToPILImage()(n_out)
img2.show()
-
将Tensor转化成PILImage
t_out = torch.randn(3,10,10)
img1 = transforms.ToPILImage()(t_out)
img1.show()
因为要求是FloatTensor类型,所以最好是不管此时的tensor是不是FloatTensor类型,都加一个强制转换再传进去。
t_out = torch.randn(3,10,10)
img1 = transforms.ToPILImage()(t_out.float())
img1.show()
3.综合处理,ndarray->Tensor->PILImage
n_out = np.random.rand(100,100,3)
t_out = transforms.ToTensor()(n_out)
img2 = transforms.ToPILImage()(t_out.float()) #强制类型转换
img2.show()
深入理解Ptorchvision.transforms.ToTensor与ToPILImage
https://blog.csdn.net/qq_37385726/article/details/81811466