## Dataset和Dataloader
### Dataset
Dataset是一个抽象类,实际使用中需要继承Dataset,并对其__len__()方法和__getitem__(idx)进行重构。前者为返回数据集长度,后者为查询idx所对应的img和其label。
### 数据增强对数据集的影响:
数据增强操作可以在Dataset中的getitem方法中实现。
class LeavesDataset(Dataset):
def __init__(self, csv, transform=None):
self.csv = csv
self.transform = transform
def __len__(self):
return len(self.csv['image'])
def __getitem__(self, idx):
img = cv2.imread(self.csv['image'][idx])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
label = labelmap[self.csv['label'][idx]]
if self.transform:
img = self.transform(image=img)['image'] # transform返回字典
return img, torch.tensor(label).type(torch.LongTensor) # img, label
假设数据集一共有100张图片,pytorch并非对数据集中的每张图片进行aug操作,将数据集扩增到200张,

最低0.47元/天 解锁文章
3840

被折叠的 条评论
为什么被折叠?



