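"""Augment the training samples and their corresponding labels.

Extra sample/label pairs are generated by randomly adjusting brightness
(and other color properties), size (cropping), etc.

Requires torchvision to be installed.
"""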
import torchvision.transforms as transforms
from PIL import Image
import torchvision
# Path prefixes -- placeholders that must be replaced before running.
# File names are built by plain string concatenation (e.g. pic_read + '0.png'),
# so a directory prefix must end with a path separator.
pic_read: str = '样本文件地址'  # source sample images ("sample file path")
label_read: str = '对应label地址'  # labels matching pic_read ("corresponding label path")
pic_save: str = '样本扩充保存地址'  # output for augmented samples ("augmented sample save path")
label_save: str = 'label扩充保存地址'  # output for augmented labels ("augmented label save path")
# def voc_rand_crop(feature, label, height, wide):
# rect = torchvision.transforms.RandomCrop.get_params(feature, (height, wide))
# feature = torchvision.transforms.functional.crop(feature, *rect)
# # print(type(feature))
# label = torchvision.transforms.functional.crop(label, *rect)
# return feature, label
def main(num_sources=1, num_augment=5000, crop_size=(512, 512)):
    """Write randomly jittered, grayscaled and cropped copies of image/label pairs.

    For each source index ``k`` in ``range(num_sources)``, the function loads the
    image ``pic_read + f'{k}.png'`` and its label ``label_read + f'{k}.png'``.
    It then writes ``num_augment`` augmented pairs. Each image gets a random
    ColorJitter, is converted to 3-channel grayscale, and is randomly cropped to
    ``crop_size``. The label is cropped with the *same* window, so the pair
    stays aligned. The label receives no color transform.

    Outputs go to ``pic_save + f'{n}.png'`` and ``label_save + f'{n}.png'``,
    where ``n = k * num_augment + j``. Pairs from different sources therefore
    never overwrite each other. With the defaults, the names are ``0.png`` ...
    ``4999.png``, as before.

    Args:
        num_sources: number of source pairs to read (indices 0..num_sources-1).
        num_augment: number of augmented pairs generated per source pair.
        crop_size: (height, width) of the random crop window.

    Raises:
        ValueError: if an image and its label differ in size, or (from
            ``RandomCrop.get_params``) if ``crop_size`` exceeds the image size.
    """
    color_jitter = transforms.ColorJitter(
        brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
    # NOTE(review): the Grayscale step discards color, so saturation/hue jitter
    # only affects the output through the luminance conversion -- confirm intent.
    to_gray = transforms.Grayscale(3)
    for src_idx in range(num_sources):
        img_path = pic_read + f'{src_idx}' + '.png'
        label_path = label_read + f'{src_idx}' + '.png'
        # Context managers close the file handles PIL keeps open for lazy loading.
        with Image.open(img_path) as img, Image.open(label_path) as label_img:
            # A shared crop window only keeps the pair aligned if sizes match;
            # otherwise the label crop would silently be misaligned/padded.
            if img.size != label_img.size:
                raise ValueError(
                    f'image and label sizes differ for index {src_idx}: '
                    f'{img.size} vs {label_img.size}')
            for aug_idx in range(num_augment):  # number of augmented copies
                # Distinct output index per (source, copy) so nothing is overwritten.
                out_idx = src_idx * num_augment + aug_idx
                gray = to_gray(color_jitter(img))
                # Draw one crop window and apply it to both image and label.
                rect = torchvision.transforms.RandomCrop.get_params(gray, crop_size)
                feature = torchvision.transforms.functional.crop(gray, *rect)
                label = torchvision.transforms.functional.crop(label_img, *rect)
                feature.save(pic_save + f'{out_idx}' + '.png')
                label.save(label_save + f'{out_idx}' + '.png')
# Script entry point: run the augmentation when executed directly.
if __name__ == "__main__":
    main()