零基础入门CV赛事 Task02:数据读取与数据扩增

学习目标

  • 学习图像读取
  • 学习扩增图片数据集的方法
  • 学习用pytorch读取赛提数据

图像读取的方法

常用的有Pillow和opencv,我选择使用opencv,完全是因为听起来更熟悉(__) 。
OpenCV是一个跨平台的计算机视觉库,最早由Intel开源得来。OpenCV发展的非常早,拥有众多的计算机视觉、数字图像处理和机器视觉等功能。OpenCV在功能上比Pillow更加强大很多,学习成本也高很多。

#代码示例
import cv2
# 导入Opencv库
img = cv2.imread('./cat.jpg')
img =cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 转换为灰度图
# Canny边缘检测
edges = cv2.Canny(img, 30, 70)
cv2.imwrite('canny.jpg', edges)

数据扩增的方法

训练集样本往往不够多,所以用数据扩增来扩展样本空间。
以torchvision为例,常见的数据扩增方法包括:

  • transforms.CenterCrop 对图片中心进行裁剪
  • transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
  • transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
  • transforms.Grayscale 对图像进行灰度变换
  • transforms.Pad 使用固定值进行像素填充
  • transforms.RandomAffine 随机仿射变换
  • transforms.RandomCrop 随机区域裁剪
  • transforms.RandomHorizontalFlip 随机水平翻转
  • transforms.RandomRotation 随机旋转
  • transforms.RandomVerticalFlip 随机垂直翻转

常用的数据扩增库有哪些?
torchvision、imgaug、albumentations

读取数据

用的是pytorch,

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数
)

for data in train_loader:
    break
torch.Size([10, 3, 64, 128]), torch.Size([10, 6]) 

总结

对数据读取进行了学习,介绍了常见的数据扩增方法和使用,使用Pytorch框架对本次赛题的数据进行了读取。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值