1.任务描述
数据读取,数据扩增,pytorch读取数据
2.图像读取
赛题数据:图像
赛题任务:识别图像字符
数据读取工具:pillow,opencv等python库
2.1pillow
2.2opencv
opencv是一个跨平台视觉库,功能强大。
3.数据扩增
数据扩增(Data Augmentation)
用途:增加训练集样本,缓解过拟合,增强泛化。
方法:从颜色空间、尺度空间到样本空间均可进行扩增;任务不同,适用的数据扩增方式也不同。
本次赛题:对字符进行识别,不能进行翻转。例如6翻转会变成9,改变字符原有含义。
torchvision中常见方法:
- transforms.CenterCrop对图片中心进行裁剪
- transforms.ColorJitter 对图像的亮度、对比度、饱和度和色调进行随机变换
- transforms.FiveCrop对图像四个角和中心进行裁剪
- transforms.Grayscale 灰度变换
- transforms.Pad 像素填充
- transforms.RandomAffine仿射变换
- transforms.RandomCrop随机区域裁剪
- transforms.RandomHorizontalFlip随机水平翻转
- transforms.RandomRotation随机旋转
- transforms.RandomVerticalFlip 随机垂直翻转
常用数据扩增库
torchvision,imgaug,albumentations
4.pytorch读取数据
pytorch中,数据通过Dataset封装,DataLoader读取。
1.Dataset:对数据集单个数据预处理然后封装。
2.DataLoader:对Dataset中的样本进行批量读取处理。
补充:`glob.glob("*.jpg")` 可获取指定目录下所有jpg文件的文件名,例如 `for file_name in glob.glob("*.jpg"):`。
5.代码
5.1 dataset对数据集封装
import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):
    """Dataset of street-view character images with fixed-length label vectors.

    Each sample is ``(image, label)``:

    * ``image`` is the file at ``img_path[index]`` opened with PIL and converted
      to RGB, then passed through ``transform`` when one is given.
    * ``label`` is ``img_label[index]`` as an int64 tensor of exactly
      ``MAX_LEN`` entries: right-padded with ``PAD_CLASS`` when shorter,
      truncated when longer.

    Args:
        img_path: Sequence of image file paths.
        img_label: Sequence of per-image label sequences, aligned index-by-index
            with ``img_path``.
        transform: Optional callable applied to the PIL image.
    """

    # Number of character positions every label is padded/truncated to.
    MAX_LEN = 5
    # Class index written into unused positions. NOTE(review): in the original
    # SVHN release label 10 denoted the digit 0; here it is used as the
    # "no character" filler -- confirm against the label convention of the json.
    PAD_CLASS = 10

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        # None means "no transform"; __getitem__ checks for it explicitly.
        self.transform = transform

    @classmethod
    def _pad_label(cls, label):
        """Return ``label`` as an int64 tensor of exactly ``cls.MAX_LEN`` entries."""
        # np.int was removed in NumPy 1.24; int64 also matches torch.long, the
        # dtype torch classification losses expect for class indices.
        digits = np.array(label, dtype=np.int64).tolist()
        padded = (digits + [cls.PAD_CLASS] * cls.MAX_LEN)[:cls.MAX_LEN]
        return torch.tensor(padded, dtype=torch.long)

    def __getitem__(self, index):
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(self.img_path[index]) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self._pad_label(self.img_label[index])

    def __len__(self):
        return len(self.img_path)
# Collect the training images in a deterministic (sorted) order.
train_path = sorted(glob.glob('../input/train/*.png'))
# Close the JSON file right after parsing instead of leaking the handle.
with open('../input/train.json', encoding='utf-8') as f:
    train_json = json.load(f)
# Look each label up by the image's file name, so labels cannot drift out of step
# with train_path if the key order inside train.json differs from the sorted glob.
# NOTE(review): assumes train.json is keyed by bare file names such as '000000.png'.
train_label = [train_json[os.path.basename(p)]['label'] for p in train_path]

data = SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       # Resize to a fixed (height, width) of (64, 128)
                       transforms.Resize((64, 128)),
                       # Random brightness / contrast / saturation jitter
                       transforms.ColorJitter(0.2, 0.2, 0.2),
                       # Random rotation within +/-5 degrees
                       transforms.RandomRotation(5),
                       # The two steps below are disabled, so each sample's image
                       # stays a PIL image.
                       # Convert the image to a PyTorch tensor:
                       # transforms.ToTensor(),
                       # Normalize pixel values per channel:
                       # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
                   ]))
5.2 DataLoader批量处理数据
import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):
    """Dataset of street-view character images with fixed-length label vectors.

    Each sample is ``(image, label)``:

    * ``image`` is the file at ``img_path[index]`` opened with PIL and converted
      to RGB, then passed through ``transform`` when one is given.
    * ``label`` is ``img_label[index]`` as an int64 tensor of exactly
      ``MAX_LEN`` entries: right-padded with ``PAD_CLASS`` when shorter,
      truncated when longer.

    Args:
        img_path: Sequence of image file paths.
        img_label: Sequence of per-image label sequences, aligned index-by-index
            with ``img_path``.
        transform: Optional callable applied to the PIL image.
    """

    # Number of character positions every label is padded/truncated to.
    MAX_LEN = 5
    # Class index written into unused positions. NOTE(review): in the original
    # SVHN release label 10 denoted the digit 0; here it is used as the
    # "no character" filler -- confirm against the label convention of the json.
    PAD_CLASS = 10

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        # None means "no transform"; __getitem__ checks for it explicitly.
        self.transform = transform

    @classmethod
    def _pad_label(cls, label):
        """Return ``label`` as an int64 tensor of exactly ``cls.MAX_LEN`` entries."""
        # np.int was removed in NumPy 1.24; int64 also matches torch.long, the
        # dtype torch classification losses expect for class indices.
        digits = np.array(label, dtype=np.int64).tolist()
        padded = (digits + [cls.PAD_CLASS] * cls.MAX_LEN)[:cls.MAX_LEN]
        return torch.tensor(padded, dtype=torch.long)

    def __getitem__(self, index):
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(self.img_path[index]) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self._pad_label(self.img_label[index])

    def __len__(self):
        return len(self.img_path)
# Collect the training images in a deterministic (sorted) order.
train_path = sorted(glob.glob('../input/train/*.png'))
# Close the JSON file right after parsing instead of leaking the handle.
with open('../input/train.json', encoding='utf-8') as f:
    train_json = json.load(f)
# Look each label up by the image's file name, so labels cannot drift out of step
# with train_path if the key order inside train.json differs from the sorted glob.
# NOTE(review): assumes train.json is keyed by bare file names such as '000000.png'.
train_label = [train_json[os.path.basename(p)]['label'] for p in train_path]

train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    # Resize to a fixed (height, width) of (64, 128)
                    transforms.Resize((64, 128)),
                    # Random brightness / contrast / saturation jitter
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    # Random rotation within +/-5 degrees
                    transforms.RandomRotation(5),
                    # PIL image -> float tensor of shape (C, H, W)
                    transforms.ToTensor(),
                    # Per-channel normalization with the commonly used ImageNet mean/std
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ])),
    batch_size=10,   # samples per batch
    shuffle=False,   # keep file order; NOTE(review): training usually uses shuffle=True
    num_workers=10,  # number of worker processes (not threads) that load samples
)

# Fetch a single batch to inspect its shapes.
data = next(iter(train_loader))
加入DataLoader之后,数据按批次获取,每批次调用Dataset读取封装单个样本。此时数据格式为
torch.Size([10, 3, 64, 128]), torch.Size([10, 5])
前者为图像数据,形状为 batch_size×channel×height×width;后者为字符标签,每批10个样本,每个标签固定为5位(不足5位用类别10填充)。