Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增

数据读取

导入需要的包以及文件路径

import json, glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

train_path = glob.glob("../mchar_train/*.png")
label_path = "../mchar_train.json"
train_json = json.load(open(label_path))
train_label = [train_json[x]['label'] for x in train_json]

图像读取

1.pillow

Pillow是Python图像处理函式库(PIL)的一个分支。Pillow提供了常见的图像读取和处理的操作,而且可以与ipython notebook无缝集成,是应用比较广泛的库。pillow官方文档

2.opencv

OpenCV是一个跨平台的计算机视觉库,最早由Intel开源得来。OpenCV发展的非常早,拥有众多的计算机视觉、数字图像处理和机器视觉等功能。OpenCV在功能上比Pillow更加强大很多,学习成本也高很多。opencv官网opencv GitHubOpenCV 扩展算法库

数据读取

pytorch有Dataset类,自己定义的Dataset都要继承这个类。参考Pytorch中文文档。这个类必须有getitem函数用来索引数据和len函数用来判断大小。transform用来对图像做变换

class SVHNDataset(Dataset):

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        if transform is None:
            self.transform = None
        else:
            self.transform = transform

    def __getitem__(self, item):
        img = Image.open(self.img_path[item]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        label = self.img_label[item]
        label = list(label) + [10]*(5-len(label))

        return img, torch.from_numpy(np.array(label[:5]))

    def __len__(self):
        return len(self.img_path)

数据扩增

在深度学习模型的训练过程中,数据扩增是必不可少的环节。现有深度学习的参数非常多,一般的模型可训练的参数量基本上都是万到百万级别,而训练集样本的数量很难有这么多。

其次数据扩增可以扩展样本空间,假设现在的分类模型需要对汽车进行分类,左边的是汽车A,右边为汽车B。如果不使用任何数据扩增方法,深度学习模型会从汽车车头的角度来进行判别,而不是汽车具体的区别。

在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。当然不同的数据扩增方法可以自由进行组合,得到更加丰富的数据扩增方法。

以torchvision为例,常见的数据扩增方法包括:

  • transforms.CenterCrop 对图片中心进行裁剪
  • transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
  • transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
  • transforms.Grayscale 对图像进行灰度变换
  • transforms.Pad 使用固定值进行像素填充
  • transforms.RandomAffine 随机仿射变换
  • transforms.RandomCrop 随机区域裁剪
  • transforms.RandomHorizontalFlip 随机水平翻转
  • transforms.RandomRotation 随机旋转
  • transforms.RandomVerticalFlip 随机垂直翻转

在Pytorch中,数据扩增在Dataloader中实现。Dataloader是数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(degrees=5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])),
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=0, # 读取的线程个数
)

以上为数据读取方式,本博客内容均来源于DataWhale

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值