今天阅读其他大佬的代码,发现了一种可以同时对image和mask做数据增强的方法,记录一下。
from torchvision.transforms import functional as F
from torchvision import transforms as tfs
from PIL import Image
import matplotlib.pyplot as plt
def rand_crop(image, label, height=300, width=300):
'''
data is PIL.Image object
label is PIL.Image object
'''
crop_params = tfs.RandomCrop.get_params(image, (height, width))
image = F.crop(image, *crop_params)
label = F.crop(label, *crop_params)
return image, label
img = Image.open("2007_000068.jpg")
lab = Image.open("2007_000068.png")
image, label = rand_crop(img, lab)
plt.figure()
plt.subplot(1, 2, 1)
plt.title("origin")
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("gray")
plt.imshow(label, cmap="gray")
plt.show()
总结:
同时对label和mask做随机裁剪,需要先获取裁剪信息,然后再同时对其做裁剪,其他增强方法或许也能参考此方法。