文章目录
一、图像红变蓝,蓝变红的问题
(1)原因分析
- 用cv2.imread()读取数据,用plt.imshow()展示数据会出现红变蓝,蓝变红的问题。
- 原因:cv2.imread() 读取图像格式为b,g,r(这是由于以前流行bgr格式的图像显示方式,今年才流行rgb格式,opencv的这个格式是历史遗留问题)而 plt.imshow()显示按照 rgb次序。因此会出现色偏,应该使用
- RGB 和 BGR 的转换可以直接使用 cvt = org[:,:,::-1] 来实现,前两个 : 表示第一第二维不变,::-1表示将第三维倒序排列。
(2)代码及结果展示
1)错误代码
import cv2
import matplotlib.pyplot as plt
image = cv2.imread('1.jpg')
# image = image[:, :, ::-1]
print(image.shape)
plt.imshow(image)
# 关闭x,y轴刻度
plt.xticks([])
plt.yticks([])
# 关闭坐标轴
plt.axis('off')
plt.show()
2)错误结果
- 左边为原图,右边为plt.imshow()显示的图片,可以发现红变蓝,蓝变红的问题
3)正确代码
- RGB 和 BGR 的转换可以直接使用 cvt = org[:,:,::-1] 来实现,前两个 : 表示第一第二维不变,::-1表示将第三维倒序排列。
import cv2
import matplotlib.pyplot as plt
image = cv2.imread('1.jpg')
# 修改的地方,bgr→rgb
image = image[:, :, ::-1]
# 或image = image[:,:,[2,1,0]]
print(image.shape)
plt.imshow(image)
# 关闭x,y轴刻度
plt.xticks([])
plt.yticks([])
# 关闭坐标轴
plt.axis('off')
plt.show()
二、深度学习数据包plt.imshow绘制图像偏蓝黄色
(1)原因分析
- 色偏问题原因:plt.imshow()在绘制2维图像时,0(最小值)显示深蓝色,1(最大值)显示黄色,其他数值显示由蓝到黄的过度颜色,所以图像画出来是蓝黄相间的。
1)原理解释代码
# coding=utf-8
from matplotlib import pyplot as plt
X = [[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]]
plt.imshow(X)
plt.colorbar()
# # 关闭x,y轴刻度
# plt.xticks([])
# plt.yticks([])
# # 关闭坐标轴
# plt.axis('off')
plt.show()
2)结果
(2)实际绘图分析与解决
- 下图每一张图像是由(28,28, 1)压缩(np.squeeze)到(28,28)的异色图像,
- 以下代码以FashionMNIST数据集为例,首先读取10个(28,28)的图片展示在一排,图片颜色维蓝黄色。
- 相关博客:【pytorch + matplotlib】将若干张图像拼接成一张图像(附代码,以FashionMNIST为例)(subplot 和 subplots区别)
1)问题代码示例
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from IPython import display
np.set_printoptions(threshold=100000000)
mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True,
transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True,
transform=transforms.ToTensor())
def use_svg_display():
"""Use svg format to display plot in jupyter"""
display.set_matplotlib_formats('svg')
# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
use_svg_display()
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
# np.squeeze将(1,28,28)→(28,28)
f.imshow(np.squeeze(img.numpy()))
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
# 完成了torch.utils.data.DataLoader的功能
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
2)问题结果
3)期望代码示例
- 绘制出黑白图像方法:将三个相同的(28,28, 1)的图像升维为(28,28, 3)的图像,然后绘制出来,就可以发现是正常的黑白图像了。
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from IPython import display
np.set_printoptions(threshold=100000000)
mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True,
transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True,
transform=transforms.ToTensor())
def use_svg_display():
"""Use svg format to display plot in jupyter"""
display.set_matplotlib_formats('svg')
# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
use_svg_display()
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
img = np.squeeze(img).numpy()
# (3, 28, 28)
image = np.array([img, img, img])
# (3, 28, 28) → (28, 28, 3)
f.imshow(np.transpose(image, (1, 2, 0)))
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
# 完成了torch.utils.data.DataLoader的功能
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
4)期望结果
- 最后终于得到了想要的黑白图像啦!
三、参考
【小记】RGB图像转换为BGR-详细讲解了如何手动将RGB图像转换为BGR图像
python将两个二维array叠加成三维array的实现方法