数据读取
pytorch内置了很多数据集,但是这次需要使用的是下载到本地的数据,这里就涉及到pytorch的自定义数据的读取了。pytorch提供了一个名为dataset的抽象类,个人感觉这个的形式更像一个接口了,重载getitem和len方法后,就可以定义一个自己的类了,使用这个类的好处是这样建立的数据集可以很好地被pytorch的其他类调用,比如Dataloader。而且在定义自己的数据类时还可以用transform方法对图片进行处理,这一部分就是涉及到下面的数据扩增了。
lass SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 设置最长的字符长度为5个
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
初始化需要用三个参数。图片路径的list和图片label的list,tranform方法不是必须的。重载getitem方法即是用index进行访问。
数据扩增
torchvision.transforms是pytorch中的图像预处理包,包含了很多种对图像数据进行变换的函数。这里有大佬做了详细总结(传送门)
裁剪——Crop
中心裁剪:transforms.CenterCrop
随机裁剪:transforms.RandomCrop
随机长宽比裁剪:transforms.RandomResizedCrop
上下左右中心裁剪:transforms.FiveCrop
上下左右中心裁剪后翻转,transforms.TenCrop
翻转和旋转——Flip and Rotation
依概率p水平翻转:transforms.RandomHorizontalFlip(p=0.5)
依概率p垂直翻转:transforms.RandomVerticalFlip(p=0.5)
随机旋转:transforms.RandomRotation
图像变换
resize:transforms.Resize
标准化:transforms.Normalize
转为tensor,并归一化至[0-1]:transforms.ToTensor
填充:transforms.Pad
修改亮度、对比度和饱和度:transforms.ColorJitter
转灰度图:transforms.Grayscale
线性变换:transforms.LinearTransformation()
仿射变换:transforms.RandomAffine
依概率p转为灰度图:transforms.RandomGrayscale
将数据转换为PILImage:transforms.ToPILImage
transforms.Lambda:Apply a user-defined lambda as a transform.
对transforms操作,使数据增强更灵活
transforms.RandomChoice(transforms), 从给定的一系列transforms中选一个进行操作
transforms.RandomApply(transforms, p=0.5),给一个transform加上概率,依概率进行操作
transforms.RandomOrder,将transforms中的操作随机打乱