天池CV赛事-街景字符编码识别(二)—— 数据读取与数据扩增


系列文章
天池CV赛事-街景字符编码识别(一) —— 赛题理解
天池CV赛事-街景字符编码识别(二) —— 数据读取与数据扩增
天池CV赛事-街景字符编码识别(三)—— 字符识别模型
天池CV赛事-街景字符编码识别(四)—— 模型训练与验证
天池CV赛事-街景字符编码识别(五)—— 模型集成


二、数据读取与数据扩增

2.1 图像读取

Python中读取图像常用的库有Pillow和OpenCV

Pillow

Pillow是图像处理函式库(PIL)的一个分支。Pillow提供了常见的图像读取和处理的操作,而且可以与ipython notebook无缝集成,是应用比较广泛的库。

Pillow的官方文档:https://pillow.readthedocs.io/en/stable/

常见的函数有:

from PIL import Image, ImageFilter —— 导入Pillow库
img = Image.open(‘lena.png’) —— 读取图片
img2 = im.filter(ImageFilter.BLUR) —— 模糊
img2.save(‘blur.jpg’, ‘jpeg’) —— 保存图片
img.thumbnail((w//2, h//2)) —— 缩放
img.save(‘thumbnail.jpg’, ‘jpeg’)

OpenCV

Pillow只是个基础的图像处理库,OpenCV是更好的选择。OpenCV拥有更多的计算机视觉,数组图像处理和机器视觉等功能,但是学习起来也比Pillow更复杂一点。

对于常用的图像处理功能,可以到上一阶段的Python+OpenCV图像处理基础知识中了解。
这里简单介绍几个:

import cv2 —— 导入Opencv库
img = cv2.imread(‘cat.jpg’) —— 读取图片
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) —— 色彩空间转换,OpenCV中通道为BRG
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) —— 转换为灰度图
edges = cv2.Canny(img, 30, 70) —— Canny边缘检测
cv2.imwrite(‘canny.jpg’, edges) —— 保存图片

OpenCV官网:https://opencv.org/
OpenCV Github:https://github.com/opencv/opencv
OpenCV 扩展算法库:https://github.com/opencv/opencv_contrib

常用数据集读取
Pytorch的torchvision中包含了很多常用的数据集,例如Imagene,MNIST,CIFAR10,VOC,SVHN

对于常用的数据集,可以通过torchvision.datasets读取,所有datasets继承torch.utils.data.Dataset,也就是说,它们实现了_getitem_和_len_方法。

pytorch支持的常用数据加载,可以参考:
http://pytorch.org/docs/stable/torchvision/datasets.html

所有datasets读取方法的API基本类似,以CIFAR10为例:

torchvision.datasets.CIFAR10(root ,train=True,transform=None,target_transform=None,download=False)

参数:

root:存放数据集的路径
train:True为从训练集创建数据集,否则从测试集创建
transform:数据预处理,如transforms.RandomRotation
target_transform:标注的预处理
download:是否下载,若为True则从互联网下载,如果在root已经存在,就不会再次下载

数据装载
数据装载,加快我们准备数据集的速度。torch.utils.data.DataLoader对Dataset进行了封装,可以利用DataLoader进行多线程批量读取。

from PIL import Image
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

train_data=torchvision.datasets.CIFAR10('./test/dataset',train=True,
                                                    transform=None,
                                                    target_transform=None,
                                                    download=True)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=2,#每一批读取图像个数
                                           shuffle=True,
                                           num_workers=0)#读取数据的线程数,win改成0,linux没问题

输出:
28%|██▊ | 48144384/170498071 [01:08<02:13, 913466.03it/s]
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./test/dataset\cifar-10-python.tar.gz

Pytorch自定义数据集读取

主要涉及两个类:

  • torch.utils.data.Dataset
  • torch.utils.data.DataLoader

Dataset类的基本结构:需要实现_getitem_和_len_方法

class MyDataset(Dataset):#继承Dataset
    def __init__(self, img_path, img_label, transform=None):
        #初始化文件路径或文件名列表,还有是否进行图像增强
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        #1,从文件中读取一个数据(例如,使用PIL.Image.open,cv2.imread)
        #2,预处理数据(例如torchvision.Transform)
        #3,返回数据对(例如图像和标签)
        pass

    def __len__(self):
        return count

2.2 数据扩增

在深度学习中数据扩增方法非常重要,数据扩增可以增加训练集的样本,同时也可以有效缓解模型过拟合的情况,也可以给模型带来的更强的泛化能力。

2.2.1 常见的数据扩增方法

两种数据扩增方法:

基于图像处理的数据扩增

  • 几何变换:旋转、缩放、翻转、裁剪、平移、仿射变换

  • 灰度和彩色空间变换:亮度调整、对比度、饱和度调整、彩色空间转换、色彩调整(对抗数据中存在的光照等偏差)

  • 添加噪声和滤波:注入高斯噪声、椒盐噪声等、滤波(模糊、锐化)(应对噪声干扰)
    ps:baseline过拟合

  • Mixing images(图像混合)(目标检测领域比较有用)

  • Random erasing(随机擦除)(应对遮挡)

基于深度学习的数据扩增
(比较耗费资源,学术上用的多,比赛用得少)

  • 基于GAN的数据增强(GAN-based Data Augmentation):使用GAN生成模型来生成更多的数据,可用作解决类别不平衡问题的过采样技术
  • 神经风格转换(Neural Style Transfer):通过神经网络风格迁移来生成不同风格的数据,防止模型过拟合
  • AutoAugment

推荐论文阅读:
《A survey on Image Data Augmentation for Deep Learning》-2019
https://link.springer.com/article/10.1186/s40537-019-0197-0#Sec3

在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。当然不同的数据扩增方法可以自由进行组合,得到更加丰富的数据扩增方法。

Pytorch中,常见的数据扩增函数主要集成在了torchvision的transforms中,这里列出19种:

裁剪

  • transforms.CenterCrop 对图片中心进行裁剪
  • transforms.RandomCrop 随机区域裁剪
  • transforms.RandomResizeCrop 随机长款比裁剪
  • transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
  • transforms.TenCrop 上下左右中心裁剪后翻转

翻转和旋转

  • transforms.RandomHorizontalFlip 依概率随机水平翻转
  • transforms.RandomRotation 随机旋转
  • transforms.RandomVerticalFlip 依概率随机垂直翻转

图像变换

  • transforms.Pad 使用固定值进行像素填充
  • transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
  • transforms.Grayscale 对图像进行灰度变换
  • transforms.RandomGrayscale 依概率灰度化
  • transforms.RandomAffine 随机仿射变换
  • transforms.LinearTransformation 线性变换
  • transforms.RandomErasing 随机选择图像中的矩形区域并擦除其像素
  • transforms.Lambda 用户自定义变换
  • transforms.Rasize 尺度缩放
  • transforms.Totensor 将PIL Image 或者numpy.ndarray格式的数据转换成tensor
  • transforms.Normalize 图像标准化(一般前面是Totensor)
2.2.2 常用的数据扩增库

torchvision

https://github.com/pytorch/vision
pytorch官方提供的数据扩增库,提供了基本的数据数据扩增方法,可以无缝与torch进行集成;但数据扩增方法种类较少,且速度中等;

imgaug

https://github.com/aleju/imgaug
imgaug是常用的第三方数据扩增库,提供了多样的数据扩增方法,且组合起来非常方便,速度较快;

albumentations

https://albumentations.readthedocs.io
是常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割、物体检测和关键点检测都支持,速度较快。

2.3 实现

在Pytorch中数据是通过Dataset进行封装,并通过DataLoder进行并行读取。所以我们只需要重载一下数据读取的逻辑就可以完成数据的读取。

有了Dataset为什么还要有DataLoder?其实这两个是两个不同的概念,是为了实现不同的功能。

  • Dataset:对数据集的封装,提供索引方式的对数据样本进行读取
  • DataLoder:对Dataset进行封装,提供批量读取的迭代读取
import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('./mchar_train/mchar_train/*.png') #获取训练集全体路径 返回list
train_path.sort() #进行排序
train_json = json.load(open('./mchar_train.json')) #返回dict
train_label = [train_json[x]['label'] for x in train_json] #返回list

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=0, # win设为0,线程数
)

for data in train_loader:
    break

在加入DataLoder后,数据按照批次获取,每批次调用Dataset读取单个样本进行拼接。此时data的格式为:
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])
前者为图像文件,为batchsize, chanel, height, width次序;后者为字符标签。

推荐阅读:
Python+OpenCV图像处理基础知识

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值