【matplotlib + opencv】关于opencv和matplotlib绘制图像时,出现色差色偏的问题探讨,思考,解决。(深度学习数据包plt.imshow绘制的图像底色偏绿蓝偏黄)

一、图像红变蓝,蓝变红的问题

(1)原因分析

  • 用cv2.imread()读取数据,用plt.imshow()展示数据会出现红变蓝,蓝变红的问题。
  • 原因:cv2.imread() 读取图像格式为b,g,r(这是由于以前流行bgr格式的图像显示方式,今年才流行rgb格式,opencv的这个格式是历史遗留问题)而 plt.imshow()显示按照 rgb次序。因此会出现色偏,应该使用
  • RGB 和 BGR 的转换可以直接使用 cvt = org[:,:,::-1] 来实现,前两个 : 表示第一第二维不变,::-1表示将第三维倒序排列。

(2)代码及结果展示

1)错误代码
import cv2
import matplotlib.pyplot as plt

image = cv2.imread('1.jpg')
# image = image[:, :, ::-1]
print(image.shape)
plt.imshow(image)
# 关闭x,y轴刻度
plt.xticks([])
plt.yticks([])
# 关闭坐标轴
plt.axis('off')
plt.show()

2)错误结果
  • 左边为原图,右边为plt.imshow()显示的图片,可以发现红变蓝,蓝变红的问题
3)正确代码
  • RGB 和 BGR 的转换可以直接使用 cvt = org[:,:,::-1] 来实现,前两个 : 表示第一第二维不变,::-1表示将第三维倒序排列。
import cv2
import matplotlib.pyplot as plt

image = cv2.imread('1.jpg')
# 修改的地方,bgr→rgb
image = image[:, :, ::-1]
# 或image = image[:,:,[2,1,0]]
print(image.shape)
plt.imshow(image)
# 关闭x,y轴刻度
plt.xticks([])
plt.yticks([])
# 关闭坐标轴
plt.axis('off')
plt.show()

二、深度学习数据包plt.imshow绘制图像偏蓝黄色

(1)原因分析

  • 色偏问题原因:plt.imshow()在绘制2维图像时,0(最小值)显示深蓝色,1(最大值)显示黄色,其他数值显示由蓝到黄的过度颜色,所以图像画出来是蓝黄相间的。
1)原理解释代码
# coding=utf-8
from matplotlib import pyplot as plt

X = [[0.1, 0.2, 0.3],
     [0.4, 0.5, 0.6],
     [0.7, 0.8, 0.9]]
plt.imshow(X)
plt.colorbar()
# # 关闭x,y轴刻度
# plt.xticks([])
# plt.yticks([])
# # 关闭坐标轴
# plt.axis('off')
plt.show()

2)结果

在这里插入图片描述

(2)实际绘图分析与解决

1)问题代码示例
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from IPython import display

np.set_printoptions(threshold=100000000)

mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True,
                                               transform=transforms.ToTensor())


def use_svg_display():
    """Use svg format to display plot in jupyter"""
    display.set_matplotlib_formats('svg')


# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
    use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        # np.squeeze将(1,28,28)→(28,28)
        f.imshow(np.squeeze(img.numpy()))
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()


# 完成了torch.utils.data.DataLoader的功能
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

2)问题结果

在这里插入图片描述

3)期望代码示例
  • 绘制出黑白图像方法:将三个相同的(28,28, 1)的图像升维为(28,28, 3)的图像,然后绘制出来,就可以发现是正常的黑白图像了。
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from IPython import display

np.set_printoptions(threshold=100000000)

mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True,
                                               transform=transforms.ToTensor())


def use_svg_display():
    """Use svg format to display plot in jupyter"""
    display.set_matplotlib_formats('svg')


# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
    use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        img = np.squeeze(img).numpy()
        # (3, 28, 28)
        image = np.array([img, img, img])
        # (3, 28, 28) → (28, 28, 3)
        f.imshow(np.transpose(image, (1, 2, 0)))
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()


# 完成了torch.utils.data.DataLoader的功能
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

4)期望结果
  • 最后终于得到了想要的黑白图像啦!
    在这里插入图片描述

三、参考

【小记】RGB图像转换为BGR-详细讲解了如何手动将RGB图像转换为BGR图像
python将两个二维array叠加成三维array的实现方法

  • 6
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值