# Data augmentation (数据扩充)
# Load the annotation table: tab-separated, no header row, one line per
# image holding the file name and its encoded mask string.
train_mask = pd.read_csv('train_mask.csv', sep='\t', names=['name', 'mask'])

# The first row serves as the running example for the previews below.
img = cv2.imread('train/'+train_mask['name'].iloc[0])
if img is None:
    # cv2.imread signals an unreadable/missing file by returning None
    # instead of raising, which would otherwise fail far from the cause.
    raise FileNotFoundError(
        'cannot read image: train/' + str(train_mask['name'].iloc[0])
    )
mask = rle_decode(train_mask['mask'].iloc[0])


def _show_pair(image, label):
    """Open a new 8x4 figure showing `image` (left) beside `label` (right).

    `image` is expected in OpenCV's BGR channel order; the channels are
    reversed for display only, because `plt.imshow` interprets a 3-channel
    array as RGB. The array passed in is not modified.
    """
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(image[:, :, ::-1])
    plt.subplot(1, 2, 2)
    plt.imshow(label)


# Original sample.
_show_pair(img, mask)

# Horizontal flip (flip code 1) applied identically to image and mask.
_show_pair(cv2.flip(img, 1), cv2.flip(mask, 1))

# Vertical flip (flip code 0).
_show_pair(cv2.flip(img, 0), cv2.flip(mask, 0))

# Random 256x256 crop. The upper bound of each offset is derived from the
# actual image size so the window always fits inside the image, rather than
# assuming a fixed 512x512 input. The same window is cut from the mask.
crop_size = 256
x = np.random.randint(0, max(img.shape[0] - crop_size, 0) + 1)
y = np.random.randint(0, max(img.shape[1] - crop_size, 0) + 1)
_show_pair(
    img[x:x+crop_size, y:y+crop_size],
    mask[x:x+crop_size, y:y+crop_size],
)
import albumentations as A
# Preview three albumentations transforms on the example sample, laid out
# on a 2x3 grid: augmented images on the top row, matching masks below.
# Each transform receives image and mask together so both get the same
# geometric change.
plt.figure(figsize=(12, 8))
demo_specs = [
    (A.HorizontalFlip, {}),
    (A.RandomCrop, {'height': 256, 'width': 256}),
    (A.ShiftScaleRotate, {}),
]
for column, (transform_cls, extra_kwargs) in enumerate(demo_specs, start=1):
    # p=1 forces the transform to be applied so the preview always shows it.
    augments = transform_cls(p=1, **extra_kwargs)(image=img, mask=mask)
    img_aug, mask_aug = augments['image'], augments['mask']

    plt.subplot(2, 3, column)
    # `img` was loaded with cv2.imread (BGR); reverse the channel axis for
    # display only, since plt.imshow interprets 3-channel arrays as RGB.
    plt.imshow(img_aug[:, :, ::-1])
    plt.subplot(2, 3, column + 3)
    plt.imshow(mask_aug)
# Augmentation pipeline that is later handed to the dataset: resize to
# 256x256, then random horizontal/vertical flips (each with probability
# 0.5) and a random 90-degree rotation.
trfm = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(),
])

# Run the pipeline once on the example sample and preview the result.
augments = trfm(image=img, mask=mask)
img_aug, mask_aug = augments['image'], augments['mask']

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
# `img` comes from cv2.imread (BGR); reverse the channel axis for display
# only, since plt.imshow interprets 3-channel arrays as RGB.
plt.imshow(img_aug[:, :, ::-1])
plt.subplot(1, 2, 2)
plt.imshow(mask_aug)
# Define the dataset (定义数据集)
import torch.utils.data as D
class TianChiDataset(D.Dataset):
    """Map-style dataset yielding an (image tensor, mask) pair per index.

    Args:
        paths: Indexable collection of image file paths; its length defines
            the dataset length.
        rles: Indexable collection of encoded masks, aligned with `paths`.
            Each entry is passed to `rle_decode` as-is.
        transform: Callable invoked as `transform(image=..., mask=...)` that
            returns a mapping with `'image'` and `'mask'` entries (the
            calling convention of an albumentations transform).
    """

    def __init__(self, paths, rles, transform):
        super().__init__()
        self.paths = paths
        self.rles = rles
        self.transform = transform
        self.len = len(paths)

    @staticmethod
    def as_tensor(image):
        """Convert an H x W x C image array into a C x H x W float32 tensor.

        `uint8` input is scaled to the [0, 1] range; other dtypes are only
        cast to float32. No mean/std normalisation is applied. Assigning a
        different callable to `instance.as_tensor` overrides this default.

        Args:
            image: 3-D NumPy array in height, width, channel order.

        Returns:
            A `torch.float32` tensor in channel, height, width order.
        """
        # Local import: the file binds only `torch.utils.data` (as `D`),
        # so the bare `torch` name is not otherwise in scope here.
        import torch

        # Copy first: torch.from_numpy rejects arrays with negative strides,
        # which flip-style augmentations can produce.
        tensor = torch.from_numpy(image.copy()).permute(2, 0, 1)
        if tensor.dtype == torch.uint8:
            return tensor.float().div(255)
        return tensor.float()

    def __getitem__(self, index):
        """Load, decode, augment and return the sample at `index`.

        Returns:
            A tuple `(image, mask)`: the augmented image converted by
            `as_tensor`, and the augmented mask with a new leading axis
            added.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img = cv2.imread(self.paths[index])
        if img is None:
            # cv2.imread returns None rather than raising on failure.
            raise FileNotFoundError(
                f'cannot read image: {self.paths[index]}'
            )
        mask = rle_decode(self.rles[index])
        # Image and mask go through the transform together so that they
        # receive the same augmentation.
        augments = self.transform(image=img, mask=mask)
        return self.as_tensor(augments['image']), augments['mask'][None]

    def __len__(self):
        """Return the number of samples (the length of `paths`)."""
        return self.len
# Build the training dataset from the annotation table: file names are
# prefixed with the image directory, and missing mask strings are replaced
# by '' before being handed to the dataset.
image_paths = 'train/'+train_mask['name'].values
mask_rles = train_mask['mask'].fillna('').values
dataset = TianChiDataset(paths=image_paths, rles=mask_rles, transform=trfm)

# Serve the dataset in shuffled mini-batches of 10 samples.
loader = D.DataLoader(dataset, batch_size=10, shuffle=True)