PyTorch学习笔记——图像处理(torchvision.ToTensor)

PIL.Image/numpy.ndarray与Tensor的相互转换

PIL.Image/numpy.ndarray转化为Tensor,常常用在训练模型阶段的数据读取,而Tensor转化为PIL.Image/numpy.ndarray则用在验证模型阶段的数据输出。

我们可以使用 transforms.ToTensor() 将 PIL.Image/numpy.ndarray 数据进转化为torch.FloadTensor,并归一化到[0, 1.0]:

取值范围为[0, 255]的PIL.Image,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloadTensor;

transforms.ToTensor()作用

ToTensor()将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可。具体可参见如下代码:

实例

图像代码

img_path = "1.jpg"

# transforms.ToTensor()
transform1 = transforms.Compose([
    transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
    ]
)

##numpy.ndarray
img = cv2.imread(img_path)# 读取图像
img1 = transform1(img) # 归一化到 [0.0,1.0]
print("img.shape = ",img.shape)
print("img1.shape = ",img1.shape)

print("img = ",img)
print("img1 = ",img1)
img.shape =  (424, 640, 3)
img1.shape =  torch.Size([3, 424, 640])
img =  [[[249 246 231]
  [248 246 228]
  [247 245 227]
  ...
  [ 17  19  20]
  [ 20  19  21]
  [ 32  14  21]]

 [[249 246 231]
  [248 246 228]
  [247 245 227]
  ...
  [ 29  30  20]
  [ 29  30  20]
  [ 15  24  14]]

 [[248 245 230]
  [248 245 230]
  [248 246 228]
  ...
  [ 80  41  26]
  [ 80  41  26]
  [ 45  36  16]]

 ...

 [[169 175 216]
  [174 178 219]
  [167 173 210]
  ...
  [ 14  17  22]
  [ 14  17  22]
  [ 15  18  23]]

 [[151 157 202]
  [163 167 208]
  [172 174 209]
  ...
  [ 14  17  22]
  [ 13  16  21]
  [ 13  16  21]]

 [[168 171 215]
  [173 174 212]
  [179 177 207]
  ...
  [ 13  16  21]
  [ 12  15  20]
  [ 11  13  21]]]
img1 =  tensor([[[0.9765, 0.9725, 0.9686,  ..., 0.0667, 0.0784, 0.1255],
         [0.9765, 0.9725, 0.9686,  ..., 0.1137, 0.1137, 0.0588],
         [0.9725, 0.9725, 0.9725,  ..., 0.3137, 0.3137, 0.1765],
         ...,
         [0.6627, 0.6824, 0.6549,  ..., 0.0549, 0.0549, 0.0588],
         [0.5922, 0.6392, 0.6745,  ..., 0.0549, 0.0510, 0.0510],
         [0.6588, 0.6784, 0.7020,  ..., 0.0510, 0.0471, 0.0431]],

        [[0.9647, 0.9647, 0.9608,  ..., 0.0745, 0.0745, 0.0549],
         [0.9647, 0.9647, 0.9608,  ..., 0.1176, 0.1176, 0.0941],
         [0.9608, 0.9608, 0.9647,  ..., 0.1608, 0.1608, 0.1412],
         ...,
         [0.6863, 0.6980, 0.6784,  ..., 0.0667, 0.0667, 0.0706],
         [0.6157, 0.6549, 0.6824,  ..., 0.0667, 0.0627, 0.0627],
         [0.6706, 0.6824, 0.6941,  ..., 0.0627, 0.0588, 0.0510]],

        [[0.9059, 0.8941, 0.8902,  ..., 0.0784, 0.0824, 0.0824],
         [0.9059, 0.8941, 0.8902,  ..., 0.0784, 0.0784, 0.0549],
         [0.9020, 0.9020, 0.8941,  ..., 0.1020, 0.1020, 0.0627],
         ...,
         [0.8471, 0.8588, 0.8235,  ..., 0.0863, 0.0863, 0.0902],
         [0.7922, 0.8157, 0.8196,  ..., 0.0863, 0.0824, 0.0824],
         [0.8431, 0.8314, 0.8118,  ..., 0.0824, 0.0784, 0.0824]]])

代码解释

(H, W, C)的nump.ndarray或img转为shape为(C, H, W)
代码中将img的 ([3, 424, 640]) 转化为 (424, 640, 3)
最后通过transforms.ToTensor,数据归一化在0到1之间

逆过程

代码

# 转化为numpy.ndarray并显示
img_1 = img1.numpy()*255
print(img_1)

# 转化为uint8类型(8位无符号整型)
img_1 = img_1.astype('uint8')
print(img_1)

img_1 = np.transpose(img_1, (1,2,0))
print(img_1)
plt.imshow(img_1)
plt.show()
print("\n-----------\n")
org = imgplt.imread('1.jpg')
plt.imshow(org)
plt.show()
[[[249. 248. 247. ...  17.  20.  32.]
  [249. 248. 247. ...  29.  29.  15.]
  [248. 248. 248. ...  80.  80.  45.]
  ...
  [169. 174. 167. ...  14.  14.  15.]
  [151. 163. 172. ...  14.  13.  13.]
  [168. 173. 179. ...  13.  12.  11.]]

 [[246. 246. 245. ...  19.  19.  14.]
  [246. 246. 245. ...  30.  30.  24.]
  [245. 245. 246. ...  41.  41.  36.]
  ...
  [175. 178. 173. ...  17.  17.  18.]
  [157. 167. 174. ...  17.  16.  16.]
  [171. 174. 177. ...  16.  15.  13.]]

 [[231. 228. 227. ...  20.  21.  21.]
  [231. 228. 227. ...  20.  20.  14.]
  [230. 230. 228. ...  26.  26.  16.]
  ...
  [216. 219. 210. ...  22.  22.  23.]
  [202. 208. 209. ...  22.  21.  21.]
  [215. 212. 207. ...  21.  20.  21.]]]
[[[249 248 247 ...  17  20  32]
  [249 248 247 ...  29  29  15]
  [248 248 248 ...  80  80  45]
  ...
  [169 174 167 ...  14  14  15]
  [151 163 172 ...  14  13  13]
  [168 173 179 ...  13  12  11]]

 [[246 246 245 ...  19  19  14]
  [246 246 245 ...  30  30  24]
  [245 245 246 ...  41  41  36]
  ...
  [175 178 173 ...  17  17  18]
  [157 167 174 ...  17  16  16]
  [171 174 177 ...  16  15  13]]

 [[231 228 227 ...  20  21  21]
  [231 228 227 ...  20  20  14]
  [230 230 228 ...  26  26  16]
  ...
  [216 219 210 ...  22  22  23]
  [202 208 209 ...  22  21  21]
  [215 212 207 ...  21  20  21]]]
[[[249 246 231]
  [248 246 228]
  [247 245 227]
  ...
  [ 17  19  20]
  [ 20  19  21]
  [ 32  14  21]]

 [[249 246 231]
  [248 246 228]
  [247 245 227]
  ...
  [ 29  30  20]
  [ 29  30  20]
  [ 15  24  14]]

 [[248 245 230]
  [248 245 230]
  [248 246 228]
  ...
  [ 80  41  26]
  [ 80  41  26]
  [ 45  36  16]]

 ...

 [[169 175 216]
  [174 178 219]
  [167 173 210]
  ...
  [ 14  17  22]
  [ 14  17  22]
  [ 15  18  23]]

 [[151 157 202]
  [163 167 208]
  [172 174 209]
  ...
  [ 14  17  22]
  [ 13  16  21]
  [ 13  16  21]]

 [[168 171 215]
  [173 174 212]
  [179 177 207]
  ...
  [ 13  16  21]
  [ 12  15  20]
  [ 11  13  21]]]

代码解释

上面的显示有点多,一个个解释

img_1 = img1.numpy()*255
这样0到1的数组就回到原有的大小

img_1 = img_1.astype(‘uint8’)
就是将数据从 浮点数变为整数

img_1 = np.transpose(img_1, (1,2,0))
就是将数组的行列值的索引值改变
我在之前的博客详细说明了transpose的用法

最后看看原图和经过torchvision.ToTensor变化又逆变化回去的图的对比
在这里插入图片描述
原图这个样子
在这里插入图片描述
变化回去的图成这样了,颜色突变了,其他都没问题
这是因为 cv.imread 读取图像格式为b,g,r,但是 plt显示按照 rgb次序
关于 处理opencv里用plt显示imread读取图像偏色问题 ,可以参考我之前的博客
所以我们需要对代码进行调整如下

##numpy.ndarray
img = cv2.imread(img_path)# 读取图像

b,g,r = cv2.split(img)
img = cv2.merge([r,g,b])

在cv2.imread后面加入这两行后,问题就解决了,我们看看输出
在这里插入图片描述
可以看出,原图和我们处理后的图一样了。

参考

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值