pytorch DataLoader(1): opencv,skimage,PIL,Tensor转换以及transforms

更新

在这里插入图片描述
在这里插入图片描述

本文进入热榜收到了不少关注,所以将本文的代码放在了GitHub上,jupyter的,有需要的自取。

同时也欢迎查看后续更新:

pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口
pytorch DataLoader(3)_albumentations数据增强(分割版)

前置知识

在使用pytorch进行dataload,transform之前,需要了解一些数据的知识,许多人使用不同的接口因为不熟悉犯了一些错误。在这里对一些常用的OpenCV,PIL,skimage进行了一些总结,以及pytorchvision.transorforms的一些简单使用。

import cv2
from PIL import Image
from skimage import io, transform, color
import matplotlib.pyplot as plt
import numpy as np

from torchvision import transforms
img_path = 'data/1803151818-00000065.jpg'
alpha_path = 'data/1803151818-00000065.png'

常用接口

1.1 OpenCV
# 默认彩图
img_cv2 = cv2.imread(img_path)

# 灰度图
img_cv2_gray = cv2.imread(alpha_path,0)

print(img_cv2.shape)
# (250, 250, 3)  (H,W,C)

type(img_cv2)
# numpy.ndarray
1.2 PIL.Image
# 默认彩图
img_pil = Image.open(img_path)

# 灰度图
img_pil_gray = Image.open(alpha_path).convert('L') # 打开图片并转成灰度图

print(img_pil.size) 
# (250, 250)

print(np.array(img_pil).shape) # PIL没有shape属性,需要转成 numpy.ndarray
#(250, 250, 3)

type(img_pil)
# PIL.JpegImagePlugin.JpegImageFile  HWC

1.3 skimage1
# 默认彩图
img_skimage = io.imread(img_path)

# 灰度图
img_skimage_gray = io.imread(alpha_path,-1)

print(img_skimage.shape)
# (250, 250, 3)

type(img_skimage)
# numpy.ndarray
# imageio.core.util.Array

(800, 600, 3)

numpy.ndarray
1.4 小结
  • OpenCV读进来的是numpy数组,是uint8类型,0-255范围,图像形状是(H,W,C),读入的顺序是BGR,这点需要注意
  • PIL是有自己的数据结构的,类型是<class ‘PIL.Image.Image’>;但是可以转换成numpy数组,转换后的数组为unit8,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
  • skimage读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
  • matplotlib读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
名称type数据类型读入图像格式数据形状能否通过transforms转换
opencvnumpy.ndarrayuint8类型,0-255范围BGRH×W×C
PILPIL.Image.ImageRGBH×W×C
skimagenumpy.ndarrayuint8类型,0-255范围RGBH×W×C
#cv2
# cv2 BGR-->RGB 两种方法
#img_cv2 = img_cv2[:,:,::-1]
img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)

plt.subplot(1,4,1)
plt.title('cv2')
plt.imshow(img_cv2)
 
#PIL
plt.subplot(1,4,2)
plt.title('PIL')
plt.imshow(img_pil)

#PIL
plt.subplot(1,4,3)
plt.title('skimage')
plt.imshow(img_skimage)

#plt
img = plt.imread(img_path)
plt.subplot(1,4,4)
plt.title('plt')
plt.imshow(img_pil)

#show
plt.show()

在这里插入图片描述

2. 相互转换

2.1 opencv <—> pil
img_cv = cv2.imread(img_path)
img_pil = Image.open(img_path)
img_skimage = io.imread(img_path)

# opencv -> pil
img_pil = Image.fromarray(cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB))

# pil -> opencv
img_cv = cv2.cvtColor(np.asarray(img_pil),cv2.COLOR_RGB2BGR)

2.2 skimage <—> pil
# skimage -> pil
img_pil = Image.fromarray(img_skimage)

# pil -> skimage
img_pil = np.array(img_skimage)
2.3 skimage <—> opencv
# opencv -> skimage
img_skimage = cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB)

# skimage -> opencv
from skimage import img_as_ubyte
cv_image = img_as_ubyte(img_skimage)

3. transforms, tensor转换

为了方便进行图像数据的操作,pytorch团队提供了一个torchvision.transforms包,我们可以用transforms进行以下操作:

  • PIL.Image / numpy.ndarray与Tensor的相互转化;
  • 归一化;
  • 对PIL.Image进行裁剪、缩放等操作。

注意1: transforms.ToTensor() 可以将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到[0, 1.0],但是transforms的其他操作只能对PIL读入的数据操作,所以使用transforms.Compose()将这些操作组合到一起的如果有其他操作则只能输入PIL数据。
transforms包含多种图像操作的函数,可以单独使用,也可以通过transforms.Compose([function1, function2,……functionN])操作。
注意2:Tensor的形状是[C,H,W],而cv2,plt,PIL,skimage形状都是[H,W,C]

3.1 H×W×C ——> C×H×W
img_cv2.transpose(2,0,1).shape
# (3,250, 250)

img_skimage.transpose(2,0,1).shape
# (3,250, 250)
(3, 800, 600)
3.2 toTensor
  • PIL.Image / numpy.ndarray --> Tensor: train 数据读取
  • Tensor --> PIL.Image / numpy.ndarray: inference 数据输出。

我们可以使用 transforms.ToTensor() 将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到[0, 1.0]:

  • 取值范围为[0, 255]的PIL.Image,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloatTensor;
  • 形状为[H, W, C]的numpy.ndarray,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloatTensor;
  • transforms.ToPILImage则是将Tensor或numpy.ndarray转化为PIL.Image。如果,我们要将Tensor转化为numpy,只需要使用 .numpy() 即可。
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
 
img_path = 'data/1803151818-00000065.jpg'
 
# transforms.ToTensor()
transform1 = transforms.Compose([
    transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0] and convert [H,W,C] to [C,H,W]
])
img = plt.imread(img_path)
print('plt',img.shape)       #(H,W,C)
img = transform1(img)  
print(img.shape)         #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255  #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8')  #convert Float to Int
print(img_arr.shape)                #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0))   #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)

在这里插入图片描述

img = cv2.imread(img_path)
#img = img[:,:,::-1] ### ValueError???
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

print('plt',img.shape)       #(H,W,C)
img = transform1(img)  
print(img.shape)         #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255  #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8')  #convert Float to Int
print(img_arr.shape)                #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0))   #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)

在这里插入图片描述

3.3 Normalize

c h a n n e l = c h a n n e l − m e a n s t d channel = \frac{channel - mean}{std} channel=stdchannelmean进行规范化。(是对tensor进行归一化,所以需要放在transforms.ToTensor()之后)

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
# 这两组值是 ImageNet数据集大样本统计得出的

#归一化 
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
    ]
)
3.4 compose
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# 先转PIL 再进入Compose 进行数据增强
all_transforms = transforms.Compose([
                    transforms.Resize(256),
                    transforms.RandomSizedCrop(224),
                    transforms.RandomHorizontalFlip(), # 对PIL.Image图片进行操作
                    transforms.ToTensor(),
                    normalize])
# 或者ToTensor之后 再转PIL
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop((300,300)),
])
img = Image.open(img_path).convert('RGB')
 
img2 = transform2(img)
 
img2.show()

在这里插入图片描述

Reference:

数据来源:爱分割 github

https://blog.csdn.net/tsq292978891/article/details/78767326


  1. Image data types and what they mean ↩︎

  • 16
    点赞
  • 59
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 18
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

烤粽子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值