PyTorch学习笔记——图像处理 torchvision.ToTensor
PIL.Image/numpy.ndarray与Tensor的相互转换
PIL.Image/numpy.ndarray转化为Tensor,常常用在训练模型阶段的数据读取,而Tensor转化为PIL.Image/numpy.ndarray则用在验证模型阶段的数据输出。
我们可以使用 transforms.ToTensor() 将 PIL.Image/numpy.ndarray 数据进转化为torch.FloadTensor,并归一化到[0, 1.0]:
取值范围为[0, 255]的PIL.Image,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloadTensor;
transforms.ToTensor()作用
ToTensor()将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可。具体可参见如下代码:
实例
图像代码
img_path = "1.jpg"
# transforms.ToTensor()
transform1 = transforms.Compose([
transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
]
)
##numpy.ndarray
img = cv2.imread(img_path)# 读取图像
img1 = transform1(img) # 归一化到 [0.0,1.0]
print("img.shape = ",img.shape)
print("img1.shape = ",img1.shape)
print("img = ",img)
print("img1 = ",img1)
img.shape = (424, 640, 3)
img1.shape = torch.Size([3, 424, 640])
img = [[[249 246 231]
[248 246 228]
[247 245 227]
...
[ 17 19 20]
[ 20 19 21]
[ 32 14 21]]
[[249 246 231]
[248 246 228]
[247 245 227]
...
[ 29 30 20]
[ 29 30 20]
[ 15 24 14]]
[[248 245 230]
[248 245 230]
[248 246 228]
...
[ 80 41 26]
[ 80 41 26]
[ 45 36 16]]
...
[[169 175 216]
[174 178 219]
[167 173 210]
...
[ 14 17 22]
[ 14 17 22]
[ 15 18 23]]
[[151 157 202]
[163 167 208]
[172 174 209]
...
[ 14 17 22]
[ 13 16 21]
[ 13 16 21]]
[[168 171 215]
[173 174 212]
[179 177 207]
...
[ 13 16 21]
[ 12 15 20]
[ 11 13 21]]]
img1 = tensor([[[0.9765, 0.9725, 0.9686, ..., 0.0667, 0.0784, 0.1255],
[0.9765, 0.9725, 0.9686, ..., 0.1137, 0.1137, 0.0588],
[0.9725, 0.9725, 0.9725, ..., 0.3137, 0.3137, 0.1765],
...,
[0.6627, 0.6824, 0.6549, ..., 0.0549, 0.0549, 0.0588],
[0.5922, 0.6392, 0.6745, ..., 0.0549, 0.0510, 0.0510],
[0.6588, 0.6784, 0.7020, ..., 0.0510, 0.0471, 0.0431]],
[[0.9647, 0.9647, 0.9608, ..., 0.0745, 0.0745, 0.0549],
[0.9647, 0.9647, 0.9608, ..., 0.1176, 0.1176, 0.0941],
[0.9608, 0.9608, 0.9647, ..., 0.1608, 0.1608, 0.1412],
...,
[0.6863, 0.6980, 0.6784, ..., 0.0667, 0.0667, 0.0706],
[0.6157, 0.6549, 0.6824, ..., 0.0667, 0.0627, 0.0627],
[0.6706, 0.6824, 0.6941, ..., 0.0627, 0.0588, 0.0510]],
[[0.9059, 0.8941, 0.8902, ..., 0.0784, 0.0824, 0.0824],
[0.9059, 0.8941, 0.8902, ..., 0.0784, 0.0784, 0.0549],
[0.9020, 0.9020, 0.8941, ..., 0.1020, 0.1020, 0.0627],
...,
[0.8471, 0.8588, 0.8235, ..., 0.0863, 0.0863, 0.0902],
[0.7922, 0.8157, 0.8196, ..., 0.0863, 0.0824, 0.0824],
[0.8431, 0.8314, 0.8118, ..., 0.0824, 0.0784, 0.0824]]])
代码解释
(H, W, C)的nump.ndarray或img转为shape为(C, H, W)
代码中将img的 ([3, 424, 640]) 转化为 (424, 640, 3)
最后通过transforms.ToTensor,数据归一化在0到1之间
逆过程
代码
# 转化为numpy.ndarray并显示
img_1 = img1.numpy()*255
print(img_1)
# 转化为uint8类型(8位无符号整型)
img_1 = img_1.astype('uint8')
print(img_1)
img_1 = np.transpose(img_1, (1,2,0))
print(img_1)
plt.imshow(img_1)
plt.show()
print("\n-----------\n")
org = imgplt.imread('1.jpg')
plt.imshow(org)
plt.show()
[[[249. 248. 247. ... 17. 20. 32.]
[249. 248. 247. ... 29. 29. 15.]
[248. 248. 248. ... 80. 80. 45.]
...
[169. 174. 167. ... 14. 14. 15.]
[151. 163. 172. ... 14. 13. 13.]
[168. 173. 179. ... 13. 12. 11.]]
[[246. 246. 245. ... 19. 19. 14.]
[246. 246. 245. ... 30. 30. 24.]
[245. 245. 246. ... 41. 41. 36.]
...
[175. 178. 173. ... 17. 17. 18.]
[157. 167. 174. ... 17. 16. 16.]
[171. 174. 177. ... 16. 15. 13.]]
[[231. 228. 227. ... 20. 21. 21.]
[231. 228. 227. ... 20. 20. 14.]
[230. 230. 228. ... 26. 26. 16.]
...
[216. 219. 210. ... 22. 22. 23.]
[202. 208. 209. ... 22. 21. 21.]
[215. 212. 207. ... 21. 20. 21.]]]
[[[249 248 247 ... 17 20 32]
[249 248 247 ... 29 29 15]
[248 248 248 ... 80 80 45]
...
[169 174 167 ... 14 14 15]
[151 163 172 ... 14 13 13]
[168 173 179 ... 13 12 11]]
[[246 246 245 ... 19 19 14]
[246 246 245 ... 30 30 24]
[245 245 246 ... 41 41 36]
...
[175 178 173 ... 17 17 18]
[157 167 174 ... 17 16 16]
[171 174 177 ... 16 15 13]]
[[231 228 227 ... 20 21 21]
[231 228 227 ... 20 20 14]
[230 230 228 ... 26 26 16]
...
[216 219 210 ... 22 22 23]
[202 208 209 ... 22 21 21]
[215 212 207 ... 21 20 21]]]
[[[249 246 231]
[248 246 228]
[247 245 227]
...
[ 17 19 20]
[ 20 19 21]
[ 32 14 21]]
[[249 246 231]
[248 246 228]
[247 245 227]
...
[ 29 30 20]
[ 29 30 20]
[ 15 24 14]]
[[248 245 230]
[248 245 230]
[248 246 228]
...
[ 80 41 26]
[ 80 41 26]
[ 45 36 16]]
...
[[169 175 216]
[174 178 219]
[167 173 210]
...
[ 14 17 22]
[ 14 17 22]
[ 15 18 23]]
[[151 157 202]
[163 167 208]
[172 174 209]
...
[ 14 17 22]
[ 13 16 21]
[ 13 16 21]]
[[168 171 215]
[173 174 212]
[179 177 207]
...
[ 13 16 21]
[ 12 15 20]
[ 11 13 21]]]
代码解释
上面的显示有点多,一个个解释
img_1 = img1.numpy()*255
这样0到1的数组就回到原有的大小
img_1 = img_1.astype(‘uint8’)
就是将数据从 浮点数变为整数
img_1 = np.transpose(img_1, (1,2,0))
就是将数组的行列值的索引值改变
我在之前的博客详细说明了transpose的用法
最后看看原图和经过torchvision.ToTensor变化又逆变化回去的图的对比
原图这个样子
变化回去的图成这样了,颜色突变了,其他都没问题
这是因为 cv.imread 读取图像格式为b,g,r,但是 plt显示按照 rgb次序
关于 处理opencv里用plt显示imread读取图像偏色问题 ,可以参考我之前的博客
所以我们需要对代码进行调整如下
##numpy.ndarray
img = cv2.imread(img_path)# 读取图像
b,g,r = cv2.split(img)
img = cv2.merge([r,g,b])
在cv2.imread后面加入这两行后,问题就解决了,我们看看输出
可以看出,原图和我们处理后的图一样了。