分类任务中,可以直接使用 torchvision 自带的 transforms 进行数据增强。
而在分割任务中,图像(img)和掩码(mask)必须施加完全相同的几何变换(翻转、旋转、平移等),否则标注会与图像错位。因此可以使用 Albumentations 库:把 image 和 mask 一起传入同一个 Compose,两者会共享同一组随机参数。
Albumentations库
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

# Training pipeline. Passing `image` and `mask` together to this Compose
# applies every spatial transform to both with the same random parameters,
# so the mask stays aligned with the image.
data_trans_train = A.Compose([
    A.Resize(height=512, width=512),
    # A.Flip is deprecated in recent Albumentations releases. Together with
    # the HorizontalFlip below, VerticalFlip covers the same set of flip
    # results (none / H / V / H+V).
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=180,
                       p=0.5, border_mode=cv2.BORDER_CONSTANT),
    A.HorizontalFlip(p=0.5),
    ToTensorV2(),
])

# Validation pipeline: resize only, with no random augmentation.
data_trans_val = A.Compose([
    A.Resize(height=512, width=512),
    ToTensorV2(),
])

# `data_seg` and `batch_size` are not defined in this snippet; they are
# assumed to come from elsewhere in the project.
train_data = data_seg(root_path='./od/train', flag='train', transforms=data_trans_train)
val_data = data_seg(root_path='./od/val', flag='val', transforms=data_trans_val)

# Shuffle the training set so each epoch visits samples in a new order.
# Validation order does not affect the metrics, so it stays deterministic.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size)
数据集类 data_seg 中 __getitem__ 的实现如下(image 与 mask 通过同一次 self.transforms 调用完成增强):
def __getitem__(self, idx):
    """Load one (image, mask) pair and augment both with one transforms call.

    ``self.transforms`` receives ``image`` and ``mask`` in a single call, so
    any random spatial transform is applied identically to both.

    Args:
        idx: Index into ``self.img_list``. The image and its label share the
            same file name under ``<root_path>/img`` and ``<root_path>/label``.

    Returns:
        ``(img, label)``. ``img`` is float32, scaled by 1/255. ``label`` is a
        float32 mask binarized to 0.0 / 1.0, with a leading dim added via
        ``unsqueeze(0)``.
        Assumes ``self.transforms`` returns torch tensors (e.g. it ends with
        ToTensorV2); the exact shapes depend on that pipeline — TODO confirm.

    Raises:
        NotImplementedError: if ``self.flag == 'test'``. This method has no
            loading path for that split.
        FileNotFoundError: if the image or label file cannot be read.
    """
    if self.flag == 'test':
        # Previously this split loaded nothing and fell through. That either
        # returned None, which failed later inside DataLoader's collate, or
        # raised UnboundLocalError. Fail early with a clear message instead.
        raise NotImplementedError("__getitem__ does not support flag='test'")

    img_name = self.img_list[idx]
    img_item_path = os.path.join(self.root_path, 'img', img_name)
    label_item_path = os.path.join(self.root_path, 'label', img_name)

    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(img_item_path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {img_item_path}")
    # NOTE: cv2 loads channels in BGR order, and they are kept that way here.
    # Convert to RGB only if inference does the same conversion.
    label = cv2.imread(label_item_path, cv2.IMREAD_GRAYSCALE)
    if label is None:
        raise FileNotFoundError(f"cannot read label: {label_item_path}")

    transformed = self.transforms(image=img, mask=label)
    img = transformed['image']
    label = transformed['mask']

    img = (img / 255.0).to(torch.float32)
    # Binarize the mask at 0.5 after scaling to [0, 1]. Values >= 0.5 become
    # 1.0 and all others 0.0, the same result as the original two-step
    # masked assignment.
    label = ((label / 255.0) >= 0.5).to(torch.float32)
    label = torch.unsqueeze(label, dim=0)
    return img, label
也可以使用 torchvision 的方法,但需要自行保证 img 和 mask 使用相同的随机参数(例如 torchvision.transforms.v2 可以将 image 与 mask 一起传入同一个变换)。