【学习笔记】零基础入门CV之街道字符识别-数据读取与扩增

我们已经了解了赛题的基本情况以及三种不同的解决方案,接下来,我们将学习如何读取赛题数据并对数据进行扩增。

1.学习目标

(1)掌握Python和pytorch的图像读取方法
(2)掌握基本的数据扩增方法,学会使用pytorch读取赛题数据

2.图像读取

Python中常见的用于图像处理的库有Pillow和OpenCV。

2.1 Pillow

Pillow是Python图像处理函式库(PIL)的一个分支。Pillow提供了常见的图像读取和处理的操作,如图像旋转、图像过滤与增强等。

from PIL import Image
from PIL import ImageFilter
from PIL import ImageEnhance
import matplotlib.pyplot as plt

img = Image.open("cat.png")
print(img.format) # 输出图片基本信息
print(img.mode) #输出图片的色彩模式
print(img.size) #输出图片的宽度和高度

img_blur = img.filter(ImageFilter.BLUR) #模糊处理
img_rotate = img.rotate(30) # 旋转
img_contour = img.filter(ImageFilter.CONTOUR) # 图片的轮廓

plt.figure()
plt.subplot(2,2,1)
plt.imshow(img)
plt.subplot(2,2,2)
plt.imshow(img_blur)
plt.subplot(2,2,3)
plt.imshow(img_rotate)
plt.subplot(2,2,4)
plt.imshow(img_contour)
plt.show()	

在这里插入图片描述

图1 pillow处理图像示例

2.2 OpenCV

OpenCV是一个跨平台的计算机视觉库,功能比Pillow更加强大。OpenCV也内置了很多图像特征处理的算法,如关键点检测、边缘检测、直线检测等。详细可以查看OpenCV官网OpenCV GithubOpenCV 扩展算法库等。

import cv2
import matplotlib.pyplot as plt

img = cv2.imread("./cat.png")
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) #Opencv默认颜色通道顺序是BGR,将其转换为RGB
img_gaussian_blur = cv2.GaussianBlur(img, (5, 5), 1, 0)  # 高斯模糊
img_gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) #转换为灰度图
img_edges = cv2.Canny(img,30,70) #Canny边缘检测

plt.figure()
plt.subplot(2,2,1)
plt.imshow(img)
plt.subplot(2,2,2)
plt.imshow(img_gaussian_blur)
plt.subplot(2,2,3)
plt.imshow(img_gray,cmap = "gray")
plt.subplot(2,2,4)
plt.imshow(img_edges,cmap = "gray")
plt.show()

在这里插入图片描述

图2 OpenCV图像处理示例

3.数据扩增

数据扩增是指不实际增加原始数据,只是对原始数据做一些变换,从而创造出更多的数据,目的是增加数据量、丰富数据多样性、缓解模型过拟合的情况,提高模型的泛化能力。

数据扩增的基本原则:
(1)不能引入无关的数据
(2)扩增总是基于先验知识的,对于不同的任务和场景,数据扩增的策略也会不同。
(3)扩增后的标签保持不变

常见的数据扩增的方法有裁剪、几何变换、加入噪声、改变对比度和亮度、模糊处理等。

常用的数据扩增库有torchvision,imgaug,albumentations.

以下给出torchvision进行数据扩增的一些例子。

from PIL import Image
from torchvision import transforms as tfs
import matplotlib.pyplot as plt

im = Image.open('./cat.png')
im_hf = tfs.RandomHorizontalFlip()(im) #随机水平翻转
im_vf = tfs.RandomVerticalFlip()(im) #随机竖直翻转
im_cj = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化

plt.figure()
plt.subplot(2,2,1)
plt.imshow(im)
plt.subplot(2,2,2)
plt.imshow(im_hf)
plt.subplot(2,2,3)
plt.imshow(im_vf)
plt.subplot(2,2,4)
plt.imshow(im_cj)
plt.show()

在这里插入图片描述

图3 torchvision进行数据扩增示例

上面这些扩增方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,这个函数就是torchvision.transforms.Compose(),例如:

im_aug = tfs.Compose([
    tfs.Resize(150),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(140),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])

nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
    for j in range(ncols):
        figs[i][j].imshow(im_aug(im))
        figs[i][j].axes.get_xaxis().set_visible(False)
        figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

在这里插入图片描述

图4 数据扩增示例

4.数据读取

下面,进行赛题数据的读取,在Pytorch中数据是通过Dataset进行封装,并通过DataLoder进行并行读取。我们重载一下Dataset。

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label,
          transforms.Compose([
              # 缩放到固定尺寸
              transforms.Resize((64, 128)),

              # 随机颜色变换
              transforms.ColorJitter(0.2, 0.2, 0.2),

              # 加入随机旋转
              transforms.RandomRotation(5),

              # 将图片转换为pytorch 的tesntor
              # transforms.ToTensor(),

              # 对图像像素进行归一化
              # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ]))

在定义好的Dataset基础上构建DataLoder。
Dataset与DataLoder的区别在于:
Dataset:对数据集的封装,提供索引的方式对数据样本进行读取
DataLoder:对Dataset进行封装,提供批量读取的方式对数据进行迭代读取.
读取数据代码修改为:

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数
)

在加入DataLoder后,数据按照批次获取,每批次调用Dataset读取单个样本进行拼接。

5.小结

本节介绍了数据读取与扩增的常用方法,使用Pytorch框架对本次赛题的数据进行读取。
掌握了数据读取的方法后,接下来就可以建立字符识别模型了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值