cv2, PIL, plt图片读取,并且与Tensor相互转化
- 读取图片
- 图片展示
- ndarray和PIL相互转化
- Tensor与ndarray和PIL相互转化
为什么要介绍这些呢?
通过下面的学习,我们在在学习深度学习过程中想要初步了解数据集并随意挑出图片进行展示,通过Tensor转化为ndarray或者PIL查看图片。
1. 图片读取
本文介绍三种图片读取方式,分别是:cv2, plt, Image。
img_path = './cat.jpg'
img_cv = cv2.imread(img_path)
img_plt = plt.imread(img_path)
img_PIL = Image.open(img_path)
print(type(img_cv)) # <class 'numpy.ndarray'>
print(type(img_plt)) # <class 'numpy.ndarray'>
print(type(img_PIL)) # <class'PIL.JpegImagePlugin.JpegImageFile'>
从上方代码可以看出,不同读取方式读取后的数据格式不同,为方便大家查看,下面给大家用表格形式进行查看。
读取方式 | 读入后的格式 | HWC顺序 | RGB顺序 |
---|---|---|---|
cv2.imread(path) | ndarray | HWC | BGR |
plt.imread(path) | ndarray | HWC | RGB |
Image.open(path) | PIL | HWC | RGB |
2. 图片展示
在这里,我选择使用plt.imshow()进行展示,比较方便。
❗要注意:plt.展示需要的是RGB顺序,所以想要展示cv2.imread()后的图片会出现问题,但是只要split, merge一下就ok。
plt.figure(figsize=(10, 30))
# cv2
plt.subplot(1, 3 , 1)
plt.title('cv2')
# b, g, r = cv2.split(img_cv)
# img_cv = cv2.merge([r, g, b])
plt.imshow(img_cv)
plt.xticks([]);plt.yticks([])
# plt
plt.subplot(1, 3 ,2)
plt.title('plt')
plt.imshow(img_plt)
plt.xticks([]);plt.yticks([])
#PIL
plt.subplot(1, 3, 3)
plt.title('PIL')
plt.imshow(img_PIL)
plt.xticks([]);plt.yticks([])
结果如上所示,由于我们cv2是BGR格式,而plt.imshow()是RGB格式,所以’cv2‘的图片有所差异。
解决方案如下:
这两句代码的意思就是将cv2的BGR格式转换成RGB格式,我们再运行一遍,会发现三个照片都一样了。
3. ndarray和PIL相互转换
目的 | 方式 |
---|---|
ndarray -> PIL | img_PIL = Image.fromarray(img_np.astype(‘unit8’)).convert(‘RGB’) |
PIL -> ndarray | img_np = np.array(img_PIL) |
举例:
""" 将PIL转化为ndarray"""
img_PIL = Image.open(img_path)
print(type(img_PIL))# <class 'PIL.JpegImagePlugin.JpegImageFile'>
img_np = np.array(img_PIL)
print(type(img_np) )# <class 'numpy.ndarray'>
""" 将ndarray转化成PIL """
img_PIL_2 = Image.fromarray(img_np.astype('uint8')).convert("RGB")
print(type(img_PIL_2)) # <class 'PIL.Image.Image'>
4. Tensor与ndarray和PIL相互转化
目的 | 方式 |
---|---|
ndarray、PIL -> Tensor | transforms.Totensor() |
Tensor -> PIL | transforms.ToPILImage() |
Tensor - > ndarray | *.array() |
举例:
4.1 Tensor与PIL相互转化
"""PIL -> Tensor -> PIL -> 显示"""
Totensor = transforms.ToTensor()
ToPIL = transforms.ToPILImage()
img_PIL = Image.open(img_path)
print(type(img_PIL)) # <class 'PIL.JpegImagePlugin.JpegImageFile'
img_tensor = Totensor(img_PIL)
print(type(img_tensor), img_tensor.shape)
# <class 'torch.Tensor'> torch.Size([3, 622, 886])
plt.figure(figsize=(10, 10))
img_ten_PIL = ToPIL(img_tensor)
print(type(img_ten_PIL)) # <class 'PIL.Image.Image'>
plt.imshow(img_ten_PIL)
plt.xticks([]);plt.yticks([])
plt.show()
❗注意:Tensor通过ToPILImage()进行转换,通道都不用改变,转换后的数据类型可以直接通过plt进行显示。
4.2 Tensor与ndarray相互转化
这里用cv2.imread()进行读取数据,并转化为tensor然后再次转化为ndarray然后进行显示。
"""ndarray -> tensor -> ndarray -> 显示"""
Totensor = transforms.ToTensor()
img_cv = cv2.imread(img_path)
print('转换前的数据类型:', type(img_cv), img_cv[6][6])
# 转换前的数据类型: <class 'numpy.ndarray'> [29 32 23]
img_ten = Totensor(img_cv)
print(type(img_ten))
# <class 'torch.Tensor'>
# 注意:Totensor会将数据转化为0 - 1 之间,使用255和astype进行还原
img_ten_np = img_ten.numpy()*255
img_ten_np = img_ten_np.astype('uint8')
# tensor 是CHW格式,但是plt.imshow()显示的是HWC格式,需进行转换
img_ten_np = np.transpose(img_ten_np, (1, 2, 0))
# cv2.imread()是BGR格式,plt.show()显示的是RGB格式,需进行转换
b, g, r = cv2.split(img_ten_np)
img_ten_np = cv2.merge([r, g, b])
print("转换后的数据类型:", type(img_ten_np), img_ten_np[6][6])
# 转换后的数据类型: <class 'numpy.ndarray'> [23 32 29]
plt.imshow(img_ten_np)
plt.show()
将cv2的图片进行转换为Tensor再转化为图片,如果想要正常显示,需要乘以255并且astype(‘uint8’)。同时,因为Tensor是CHW格式,如果要转为ndarray,需要通过np.transpose(img, (1,2,0))转换一下维度。
4.3 plt -> tensor - > plt -> 显示
plt的话就更简单,不会像cv2那样需要利用split和merge改变颜色通道。
"""plt -> tensor -> plt -> 显示"""
img_plt = plt.imread(img_path)
print(type(img_plt))
img_ten = Totensor(img_plt)
print(type(img_ten))
img_ten_plt = img_ten.numpy()*255
img_ten_plt = img_ten_plt.astype('uint8')
img_ten_plt = np.transpose(img_ten_plt, (1, 2, 0))
plt.imshow(img_ten_plt)
plt.show()