Datawhale零基础⼊⻔CV-Task2 数据读取与数据扩增

1.任务描述

数据读取,数据扩增,pytorch读取数据

2.图像读取

赛题数据:图像
赛题任务:识别图像字符
数据读取工具:pillow,opencv等python库

2.1pillow

在这里插入图片描述
在这里插入图片描述

2.2opencv

opencv是一个跨平台视觉库,功能强大。
在这里插入图片描述

3.数据扩增

数据扩增(Data Augmentation)
用途:增加训练集样本,缓解过拟合,增强泛化。
方法:颜色空间、尺度空间到样本空间,任务不同数据扩增方式也不同。
本次赛题:对字符进行识别,不能进行翻转。例如6翻转会变成9,改变字符原有含义。
torchvision中常见方法

  1. transforms.CenterCrop对图片中心进行裁剪
  2. transforms.ColorJitter 对颜色对比度饱和度和零度变换
  3. transforms.FiveCrop对图像四个角和中心进行裁剪
  4. transforms.Grayscale 灰度变换
  5. transforms.Pad 像素填充
  6. transforms.RandomAffine仿射变换
  7. transforms.RandomCrop随机区域裁剪
  8. transforms.RandomHorizontalFlip随机水平翻转
  9. transforms.RandomRotation随机旋转
    10.transforms.RandomVerticalFlip 随机垂直翻转
    常用数据扩增库
    torchvision,imgaug,albumentations

4.pytorch读取数据

pytorch中,数据通过Dataset封装,DataLoader读取。
1.Dataset:对数据集单个数据预处理然后封装。
2.DataLoader:dataset批量读取处理。
for file_name in glob.glob("*.jpg"):#获取指定目标下所有的jpg文件的文件名

5.代码

5.1 dataset对数据集封装

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label,
          transforms.Compose([
              # 缩放到固定尺寸
              transforms.Resize((64, 128)),

              # 随机颜色变换
              transforms.ColorJitter(0.2, 0.2, 0.2),

              # 加入随机旋转
              transforms.RandomRotation(5),

              # 将图片转换为pytorch 的tesntor
              # transforms.ToTensor(),

              # 对图像像素进行归一化
              # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ]))

5.2 DataLoader批量处理数据

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中类别10为数字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数
)

for data in train_loader:
    break

加入DataLoader之后,数据按批次获取,每批次调用Dataset读取封装单个样本。此时数据格式为
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])
前者为图像文件batchsizechannelheight*weight,后者为字符标签每批10个,字符位数为6。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值