零基础入门CV赛事(二)

数据读取与数据扩增

从这一节开始都会使用【定长识别字符】的思路去解决问题

本章主要内容为数据读取数据扩增方法Pytorch读取赛题数据三个部分组成

学习目标

  • 学习Python和Pytorch中图像读取
  • 学会扩增方法和Pytorch读取赛题数据

图像读取

既然是字符识别那首要的任务就是先完成对所需要识别的图像的读取

比较常见的通过python读取的包有PillowOpenCV,这里仅列出读取方法,更多操作见对应手册

1. Pillow

pillow读取图像方式

from PIL import Image

im =Image.open(image_path)
im_gray = Image.open(image_path).convert("L")

2. OpenCV

OpenCV读取图像方式

import cv2

img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Opencv默认颜色通道顺序是BRG,转换一下
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 

数据扩增

  • 数据扩增介绍
    数据扩增可以增加训练集的样本,同时也可以有效缓解模型过拟合的情况,也可以给模型带来的更强的泛化能力。直观上看就是让数据集更加大

    在这里插入图片描述

  • 数据扩增为什么有用
    差不多的数据,多出一些,有作用吗?

    在深度学习模型的训练过程中,数据扩增是必不可少的环节。现有深度学习的参数非常多,一般的模型可训练的参数量基本上都是万到百万级别,而训练集样本的数量很难有这么多

    其次数据扩增可以扩展样本空间,假设现在的分类模型需要对汽车进行分类,左边的是汽车A,右边为汽车B。如果不使用任何数据扩增方法,深度学习模型会从汽车车头的角度来进行判别,而不是汽车具体的区别

    在这里插入图片描述

    数据扩增的前提是有足够的原始数据集,如果我原始的数据就几辆车,那也起不到效果

  • 有哪些数据扩增方法?
    数据扩增方法有很多:从颜色空间、尺度空间到样本空间,同时根据不同任务数据扩增都有相应的区别

    对于图像分类,数据扩增一般不会改变标签;对于物体检测,数据扩增会改变物体坐标位置;对于图像分割,数据扩增会改变像素标签

1. 常见的数据扩增方法

在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。当然不同的数据扩增方法可以自由进行组合,得到更加丰富的数据扩增方法。

torchvision为例,常见的数据扩增方法包括:

  • transforms.CenterCrop 对图片中心进行裁剪
  • transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
  • transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
  • transforms.Grayscale 对图像进行灰度变换
  • transforms.Pad 使用固定值进行像素填充
  • transforms.RandomAffine 随机仿射变换
  • transforms.RandomCrop 随机区域裁剪
  • transforms.RandomHorizontalFlip 随机水平翻转
  • transforms.RandomRotation 随机旋转
  • transforms.RandomVerticalFlip 随机垂直翻转

在这里插入图片描述

2. 常用的数据扩增库

  • torchvision
    https://github.com/pytorch/vision
    pytorch官方提供的数据扩增库,提供了基本的数据数据扩增方法,可以无缝与torch进行集成;但数据扩增方法种类较少,且速度中等;

  • imgaug
    https://github.com/aleju/imgaug
    imgaug是常用的第三方数据扩增库,提供了多样的数据扩增方法,且组合起来非常方便,速度较快;

  • albumentations
    https://albumentations.readthedocs.io
    是常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割、物体检测和关键点检测都支持,速度较快

Pytorch读取数据

在Pytorch中数据是通过Dataset进行封装,并通过DataLoder进行并行读取。所以我们只需要重载一下数据读取的逻辑就可以完成数据的读取

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (6 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:6]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('mchar_train/*.png')
train_path.sort()
train_json = json.load(open('mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label,
          transforms.Compose([
              # 缩放到固定尺寸
              transforms.Resize((64, 128)),

              # 随机颜色变换
              transforms.ColorJitter(0.2, 0.2, 0.2),

              # 加入随机旋转
              transforms.RandomRotation(5),

              # 将图片转换为pytorch 的tesntor
              # transforms.ToTensor(),

              # 对图像像素进行归一化
              # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ]))

代码理解:

  1. SVHNDataset这是建立我们数据集的一个类,继承于Dataset
  2. SVHNDataset的初始化需要三个参数,分别为数据集标签变换方式
  3. SVHNDataset__getitem__方法使得数据集可以进行迭代或通过下标进行索引,
  4. SVHNDataset__len__方法使得可以查看数据集总量
  5. 数据集和标签集的加载部分
    train_path获得的是一个以图片名为内容的列表
    train_label获得的是包含每张图片所有数字的列表的列表
  6. 最后的生成data的语句,指定了数据集、标签,并引入尺寸、颜色、角度三个变换
    可以通过data[?][0]的方式查看图片
    可以通过data[?][1]的方式查看图片对应的标签

例如我们生成了data、data1、data2三个数据集,分别取了对应的图片看看不同的变换效果

1 2 3
在这里插入图片描述 在这里插入图片描述 在这里插入图片描述
在这里插入图片描述 在这里插入图片描述 在这里插入图片描述

接下来我们将在定义好的Dataset基础上构建DataLoder,有了Dataset为什么还要有DataLoder?其实这两个是两个不同的概念,是为了实现不同的功能。

  • Dataset:对数据集的封装,提供索引方式的对数据样本进行读取
  • DataLoder:对Dataset进行封装,提供批量读取的迭代读取

就像我是卖饼干的,用Dataset把饼干一包包都装好,但是我卖出去的时候是一箱一箱卖的,而且箱子的大小可以不一样,卖的摊位的数量也可以设置,卖的顺序可以和出厂顺序不同等

加入DataLoder后,数据读取代码改为如下:

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数
)

其实就是在原来的基础上加了这样一段代码

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    # num_workers=10, # 读取的线程个数 Linux
    num_workers=0, # 读取的线程个数 Win
)

可以通过以下语句来查看我们的train_loader,仅查看第一个批量

for batch in train_loader:
    print(batch[0].shape)
    print(batch[1].shape)
    break
torch.Size([10, 3, 64, 128])
torch.Size([10, 6])

前者为图像文件,为batchsize * chanel * height * width次序;后者为字符标签

这里的6和前面数据集的字符填充的总长度设置有关,我把5改成了6

本章小节

本章对数据读取进行了详细的讲解,并介绍了常见的数据扩增方法和使用,最后使用Pytorch框架对本次赛题的数据进行读取

展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 游动-白 设计师: 上身试试
应支付0元
点击重新获取
扫码支付

支付成功即可阅读