import albumentations as A
import torch
import math
import random
import os
import cv2
import shutil
import numpy as np
import argparse
from torchvision import transforms
from PIL import Image
def make_odd(num):
    """Round *num* up to the nearest integer and bump even results to the next odd."""
    rounded = math.ceil(num)
    return rounded + 1 if rounded % 2 == 0 else rounded
def med_augment(data_path, name, level, number_branch, mask_i=False, shield=False):
    """Write ``number_branch`` augmented copies of every image in one folder.

    Builds a pool of 6 pixel-level and 7 spatial-level albumentations
    transforms whose magnitudes/probabilities scale with ``level``, then for
    each .png/.jpg image samples a small random subset of that pool and saves
    the transformed result. Unless ``shield`` is set, the unaugmented original
    is also copied as branch ``number_branch + 1``.

    Args:
        data_path: Root folder containing the class/split sub-folders.
        name: Sub-folder (class or split name) whose images are augmented.
        level: Augmentation strength (scales transform magnitudes and probs).
        number_branch: Number of augmented copies generated per image.
        mask_i: If True (segmentation mode), also transform the paired
            ``{name}_mask`` folder so image and mask stay spatially aligned.
        shield: If True, skip copying the unaugmented original to the output.
    """
    if mask_i:
        image_path = f"{data_path}{name}"
        mask_path = f"{image_path}_mask"
        root = os.path.dirname(os.path.dirname(data_path))
        output_path = f"{root}/medaugment/{name}/"
        out_mask = f"{root}/medaugment/{name}_mask/"
    else:
        image_path = data_path + name
        root = os.path.dirname(os.path.dirname(os.path.dirname(data_path)))
        output_path = f"{root}/medaugment/training/{name}/"
        out_mask = None
    # Create output folders once, up front. The original only ever created
    # output_path (and re-checked it inside the inner loop); writes to a
    # missing out_mask dir would silently fail (cv2.imwrite returns False).
    os.makedirs(output_path, exist_ok=True)
    if out_mask is not None:
        os.makedirs(out_mask, exist_ok=True)
    # Pool of transforms: indices [0:6] are pixel-level, [6:] are spatial-level.
    transform = A.Compose([
        A.ColorJitter(brightness=0.04 * level, contrast=0, saturation=0, hue=0, p=0.2 * level),
        A.ColorJitter(brightness=0, contrast=0.04 * level, saturation=0, hue=0, p=0.2 * level),
        A.Posterize(num_bits=math.floor(8 - 0.8 * level), p=0.2 * level),
        A.Sharpen(alpha=(0.04 * level, 0.1 * level), lightness=(1, 1), p=0.2 * level),
        A.GaussianBlur(blur_limit=(3, make_odd(3 + 0.8 * level)), p=0.2 * level),
        A.GaussNoise(var_limit=(2 * level, 10 * level), mean=0, per_channel=True, p=0.2 * level),
        A.Rotate(limit=4 * level, interpolation=1, border_mode=0, value=0, mask_value=None,
                 rotate_method='largest_box', crop_border=False, p=0.2 * level),
        A.HorizontalFlip(p=0.2 * level),
        A.VerticalFlip(p=0.2 * level),
        A.Affine(scale=(1 - 0.04 * level, 1 + 0.04 * level), translate_percent=None, translate_px=None,
                 rotate=None, shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0,
                 mode=0, fit_output=False, keep_ratio=True, p=0.2 * level),
        A.Affine(scale=None, translate_percent=None, translate_px=None, rotate=None,
                 shear={'x': (0, 2 * level), 'y': (0, 0)}, interpolation=1, mask_interpolation=0,
                 cval=0, cval_mask=0, mode=0, fit_output=False, keep_ratio=True, p=0.2 * level),  # x shear
        A.Affine(scale=None, translate_percent=None, translate_px=None, rotate=None,
                 shear={'x': (0, 0), 'y': (0, 2 * level)}, interpolation=1, mask_interpolation=0,
                 cval=0, cval_mask=0, mode=0, fit_output=False, keep_ratio=True, p=0.2 * level),  # y shear
        A.Affine(scale=None, translate_percent={'x': (0, 0.02 * level), 'y': (0, 0)}, translate_px=None,
                 rotate=None, shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0,
                 mode=0, fit_output=False, keep_ratio=True, p=0.2 * level),  # x translation
        A.Affine(scale=None, translate_percent={'x': (0, 0), 'y': (0, 0.02 * level)}, translate_px=None,
                 rotate=None, shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0,
                 mode=0, fit_output=False, keep_ratio=True, p=0.2 * level)  # y translation
    ])
    for file_name in os.listdir(image_path):
        if not file_name.endswith((".png", ".jpg")):
            continue
        file_path = os.path.join(image_path, file_name)
        # splitext keeps multi-dot stems intact ("a.b.png" -> "a.b", ".png");
        # the original split on the first "." and mangled such names.
        stem, ext = os.path.splitext(file_name)
        image = cv2.imread(file_path)
        if image is None:
            continue  # unreadable/corrupt file: skip instead of crashing later
        if mask_i:
            mask = cv2.imread(f"{mask_path}/{stem}_mask{ext}")
        # Each pair is (number of pixel-level ops, number of spatial-level ops).
        strategy = [(1, 2), (0, 3), (0, 2), (1, 1)]
        for i in range(number_branch):
            if number_branch != 4:
                employ = random.choice(strategy)
            else:
                # With exactly 4 branches, use each strategy exactly once.
                employ = strategy.pop(random.randrange(len(strategy)))
            # NOTE: the original assigned these to "level"/"shape", rebinding
            # (and shadowing) the ``level`` parameter — renamed for clarity.
            pixel_ops = random.sample(transform[:6], employ[0])
            spatial_ops = random.sample(transform[6:], employ[1])
            img_transform = A.Compose([*pixel_ops, *spatial_ops])
            random.shuffle(img_transform.transforms)
            if mask_i:
                transformed = img_transform(image=image, mask=mask)
                cv2.imwrite(f"{output_path}/{stem}_{i + 1}{ext}", transformed['image'])
                cv2.imwrite(f"{out_mask}/{stem}_{i + 1}_mask{ext}", transformed['mask'])
            else:
                transformed = img_transform(image=image)
                cv2.imwrite(f"{output_path}/{stem}_{i + 1}{ext}", transformed['image'])
        if not shield:
            # Keep the unaugmented original alongside the augmented branches.
            cv2.imwrite(f"{output_path}/{stem}_{number_branch + 1}{ext}", image)
            if mask_i:
                cv2.imwrite(f"{out_mask}/{stem}_{number_branch + 1}_mask{ext}", mask)
def generate_datasets(train_type, dataset, seed, level, number_branch):
    """Build the ``medaugment`` dataset variant next to the ``baseline`` one.

    Seeds all RNG sources for reproducibility, mirrors the baseline folder
    layout (excluding the training data itself), then fills the training
    folder(s) with augmented images via ``med_augment``.

    Args:
        train_type: 'classification' or 'segmentation'.
        dataset: Dataset folder name under ``./datasets/<train_type>/``.
        seed: Seed applied to torch, random and numpy.
        level: Augmentation strength, forwarded to ``med_augment``.
        number_branch: Augmented copies per image, forwarded to ``med_augment``.

    Raises:
        FileExistsError: if a ``medaugment`` folder from a previous run is
            still present (``shutil.copytree`` refuses to overwrite it).
    """
    # Seed every RNG the augmentation pipeline may draw from.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
    if train_type == "classification":
        print('Executing data augmentation for image classification...')
        data_path = f"./datasets/classification/{dataset}/baseline/training/"
        folder_path = f"./datasets/classification/{dataset}/"
        # Use the actual class sub-folder names. The original counted the
        # folders and then assumed they were named n0..n{k-1}, which breaks
        # for any other naming scheme; sorting keeps the order deterministic.
        baseline_training = f"{folder_path}baseline/training"
        class_names = sorted(
            entry for entry in os.listdir(baseline_training)
            if os.path.isdir(os.path.join(baseline_training, entry))
        )
        # Mirror the baseline layout, minus the training images themselves.
        shutil.copytree(f"{folder_path}baseline", f"{folder_path}medaugment",
                        ignore=shutil.ignore_patterns("training"))
        training_folder_path = f"{folder_path}medaugment/training"
        os.makedirs(training_folder_path)
        for class_name in class_names:
            os.makedirs(f"{training_folder_path}/{class_name}")
        for class_name in class_names:
            med_augment(data_path, class_name, level, number_branch)
    else:
        print('Executing data augmentation for image segmentation...')
        data_path = f"./datasets/segmentation/{dataset}/baseline/"
        folder_path = f"./datasets/segmentation/{dataset}/"
        # Mirror the baseline layout, minus the training images and masks.
        shutil.copytree(f"{folder_path}baseline", f"{folder_path}medaugment",
                        ignore=shutil.ignore_patterns("training", "training_mask"))
        os.makedirs(f"{folder_path}medaugment/training")
        os.makedirs(f"{folder_path}medaugment/training_mask")
        med_augment(data_path, "training", level, number_branch, mask_i=True)
def main():
    """Parse command-line options and run dataset generation."""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
    opts = parser.add_argument_group()
    opts.add_argument('--dataset', required=True)
    opts.add_argument('--train_type', choices=['classification', 'segmentation'], default='classification')
    opts.add_argument('--level', help='Augmentation level', default=5, type=int, metavar='INT')
    opts.add_argument('--number_branch', help='Number of branch', default=4, type=int, metavar='INT')
    opts.add_argument('--seed', help='Seed', default=8, type=int, metavar='INT')
    parsed = parser.parse_args()
    # Forward every parsed option to generate_datasets as keyword arguments.
    generate_datasets(**vars(parsed))
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
- 102.
- 103.
- 104.
- 105.
- 106.
- 107.
- 108.
- 109.
- 110.
- 111.
- 112.
- 113.
- 114.
- 115.
- 116.
- 117.
- 118.
- 119.
- 120.
- 121.
- 122.
- 123.
- 124.
- 125.
- 126.
- 127.
- 128.
- 129.
- 130.
- 131.
- 132.
- 133.
- 134.
- 135.
- 136.
- 137.
- 138.
- 139.
- 140.
- 141.
- 142.
- 143.
- 144.
- 145.
- 146.
- 147.