对数据集进行封装:
class SVHNDataset(Dataset):
    """Map-style dataset pairing number images with fixed-length digit labels.

    Each sample is ``(image, label)``. ``image`` is the RGB image loaded from
    disk, passed through ``transform`` when one is given. ``label`` is an
    int64 tensor of exactly ``MAX_DIGITS`` class indices: shorter digit
    sequences are right-padded with ``PAD_LABEL`` and longer ones truncated.

    Args:
        img_path: Indexable collection of image file paths.
        img_label: Indexable collection of per-image digit sequences,
            aligned index-for-index with ``img_path``.
        transform: Optional callable applied to the loaded PIL image.
    """

    # Number of digit slots in every returned label.
    MAX_DIGITS = 5
    # Class index written into unused digit slots.
    PAD_LABEL = 10

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        # convert() yields an independent in-memory copy, so the underlying
        # file can be closed right away instead of lingering until GC.
        with Image.open(self.img_path[index]) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Original author's note: in the raw SVHN data, class 10 denotes the
        # digit 0.
        # NOTE(review): 10 is also the padding value used below -- confirm
        # the label files already store the digit zero as 0.
        # np.int was removed in NumPy 1.24; an explicit int64 keeps this
        # working and gives the same label dtype on every platform.
        digits = np.asarray(self.img_label[index], dtype=np.int64)
        digits = digits.reshape(-1)[:self.MAX_DIGITS]

        lbl = np.full(self.MAX_DIGITS, self.PAD_LABEL, dtype=np.int64)
        lbl[:len(digits)] = digits
        return img, torch.from_numpy(lbl)

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.img_path)
封装该类后,就可以像访问数组一样,通过索引对数据集进行访问。
数据增强:
# Training dataset whose images are augmented on the fly as they are read.
data = SVHNDataset(
    train_path,
    train_label,
    transform=transforms.Compose([
        # Resize every image to a fixed size.
        transforms.Resize((64, 128)),
        # Random colour jitter.
        transforms.ColorJitter(0.2, 0.2, 0.2),
        # Small random rotation.
        transforms.RandomRotation(5),
        # NOTE(review): tensor conversion and normalisation are disabled
        # here, so samples stay PIL images -- presumably for previewing the
        # augmentations; confirm before using this dataset for training.
        # transforms.ToTensor(),
        # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
)
通过上述代码,可以读取赛题的图像数据及其对应标签,并在读取过程中进行数据扩增。
加载数据集:
通过 PyTorch 自带的 DataLoader,可以使用多个工作进程将数据集自动按照 batch_size 大小划分为批次,并返回可迭代对象。
# Batched loader over the augmented, tensorised and normalised training set.
train_loader = torch.utils.data.DataLoader(
    SVHNDataset(
        train_path,
        train_label,
        transform=transforms.Compose([
            transforms.Resize((64, 128)),
            transforms.ColorJitter(0.3, 0.3, 0.2),
            transforms.RandomRotation(5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    ),
    # Samples per batch.
    batch_size=10,
    # Keep dataset order.
    # NOTE(review): training loaders usually shuffle -- confirm that a fixed
    # order is intended here.
    shuffle=False,
    # Worker subprocesses used to load samples in parallel.
    num_workers=10,
)