零基础入门CV赛事-Task2

2 数据读取与数据扩增

这次从定长字符识别来构建模型
参考:https://github.com/datawhalechina/team-learning/blob/master/03%20%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89%E5%AE%9E%E8%B7%B5%EF%BC%88%E8%A1%97%E6%99%AF%E5%AD%97%E7%AC%A6%E7%BC%96%E7%A0%81%E8%AF%86%E5%88%AB%EF%BC%89/Datawhale%20%E9%9B%B6%E5%9F%BA%E7%A1%80%E5%85%A5%E9%97%A8CV%20-%20Task%2002%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%8F%96%E4%B8%8E%E6%95%B0%E6%8D%AE%E6%89%A9%E5%A2%9E.md

图像读取

Python中有比较常见的有Pillow和OpenCV。
Pillow:

from PIL import Image
# 打开一个png图像文件,注意是当前路径:
im1 = Image.open('111.png')
im2 = Image.open('222.png')
# 应用模糊滤镜,保存文件名为‘blur.jpg’
im1 = im.filter(ImageFilter.BLUR)
im1.save('blur.jpg', 'jpeg')
# 图像缩小
h = im2.size[0]
w = im2.size[1]
im2.thumbnail((h//3, w//3))
im2.save('blur2.jpg', 'jpeg')
im1.show()
im2.show()

OpenCV:
OpenCV在功能上比Pillow更加强大很多,学习成本也高很多.

import cv2
img = cv2.imread('111.png')
# 转换为灰度图  
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 
#边缘检测
#第二、三个参数分别是两个阈值,上限和下限
edges = cv2.Canny(img, 30, 70)
cv2.imshow("Canny边缘检测", edges)   
cv2.waitKey (0)  
cv2.imwrite('222.jpg',edges)

OpenCV官网:https://opencv.org/
OpenCV Github:https://github.com/opencv/opencv
OpenCV 扩展算法库:https://github.com/opencv/opencv_contrib

数据扩增

在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。
torchvision:
pytorch官方提供的数据扩增库,提供了基本的数据数据扩增方法,可以无缝与torch进行集成;但数据扩增方法种类较少,且速度中等;
imgaug:
mgaug是常用的第三方数据扩增库,提供了多样的数据扩增方法,且组合起来非常方便,速度较快;
albumentations:
是常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割、物体检测和关键点检测都支持,速度较快。

以torchvision为例,常见的数据扩增方法包括:
中心裁剪

torchvision.transforms.CenterCrop(size)

随机水平翻转

torchvision.transforms.RandomHorizontalFlip(p=0.5)
#p- 概率,默认值为 0.5

亮度对比度饱和度变换

torchvision.transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0)
# 亮度 —— [max(0, 1 - brightness), 1 + brightness]或给定[min,max]
# 对比度 —— [max(0, 1 - contrast), 1 + contrast]或给定[min,max]
# 饱和度 —— [max(0, 1 - saturation), 1 + saturation]或给定[min,max]
# 色调 —— [-hue, hue] && 0<=hue<= 0.5或-0.5 <= min <= max <= 0.5

转灰度图

torchvision.transforms.RandomGrayscale(p=0.1)
#依概率p将图片转换为灰度图

等等

Pytorch读取数据

这里只需要重载一下数据读取的逻辑就可以完成数据的读取。

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
#类
class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
#数据扩增
data = SVHNDataset(train_path, train_label,
          transforms.Compose([
              # 缩放到固定尺寸
              transforms.Resize((64, 128)),
              # 随机颜色变换
              transforms.ColorJitter(0.2, 0.2, 0.2),
              # 加入随机旋转
              transforms.RandomRotation(5),

              # 将图片转换为pytorch 的tesntor
              # transforms.ToTensor(),

              # 对图像像素进行归一化
              # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ]))

接下来在定义好的Dataset基础上构建DataLoder。
Dataset:对数据集的封装,提供索引方式的对数据样本进行读取
DataLoder:对Dataset进行封装,提供批量读取的迭代读取
加入DataLoder后,数据读取代码改为如下:(部分重合代码省略):

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数
)
for data in train_loader:
    break

总结

学习了数据读取与常见的数据扩增方法和使用,以及使用Pytorch框架对本次赛题的数据进行读取。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值