学习目标
- 学习Python和Pytorch中图像读取
- 学会扩增方法和Pytorch读取塞梯数据
图像读取
常见的库:
Pillow (PIL)
OpenCV
Pytorch读取数据
import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 原始SVHN中类别10为数字0
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
data = SVHNDataset(train_path, train_label,
transforms.Compose([
# 缩放到固定尺寸
transforms.Resize((64, 128)),
# 随机颜色变换
transforms.ColorJitter(0.2, 0.2, 0.2),
# 加入随机旋转
transforms.RandomRotation(5),
# 将图片转换为pytorch 的tesntor
# transforms.ToTensor(),
# 对图像像素进行归一化
# transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
]))
图像扩增
目的:增加训练样本。
常见的数据扩增方法:
颜色空间,尺度空间,样本空间。
对于图像分类,数据扩增一般不会改变标签;对于物体检测,数据扩增改变物体坐标位置;对于图像分割,数据扩增改变像素标签。
数据扩增库:
torchvision
imgaug
albumentations
常见方法: