Adding a Custom Data Augmentation Method to YOLOv5
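YOLOv5 already ships with a rich set of built-in data augmentations, including random rotation, flipping, HSV saturation jitter, and more. Still, there are cases where you need to add your own. One example is training a network with an N+L strategy, where N+L denotes a dataset that mixes Normal-resolution and Low-resolution images; this requires adding an image degradation algorithm to YOLOv5. Here we take a degradation model designed for blind super-resolution [1] as the example (referred to below as B-DEGRADE; its source code comes from the BSRGAN project) and show how to add a custom augmentation method to a YOLOv5 project.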
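To make "degradation" concrete, here is a minimal sketch of the classic blur, downsample, add-noise pipeline that turns a normal-resolution image into a low-resolution one. It is a deliberately simplified stand-in, not BSRGAN's B-DEGRADE (which randomly shuffles blur, downsampling, and several noise types including JPEG compression); the names `box_blur` and `simple_degrade` and their parameters are hypothetical.

```python
import numpy as np


def box_blur(img: np.ndarray, k: int = 3) -> np.ndarray:
    """Average each pixel over a k x k window (edge-padded), per channel. k should be odd."""
    pad = k // 2
    padded = np.pad(img.astype(np.float32), ((pad, pad), (pad, pad), (0, 0)), mode="edge")
    out = np.zeros(img.shape, dtype=np.float32)
    h, w = img.shape[:2]
    for dy in range(k):
        for dx in range(k):
            out += padded[dy:dy + h, dx:dx + w]
    return out / (k * k)


def simple_degrade(img: np.ndarray, sf: int = 2, noise_sigma: float = 5.0, rng=None) -> np.ndarray:
    """Hypothetical stand-in for B-DEGRADE: blur, downsample by sf, add Gaussian noise.

    img: H x W x C uint8 array. Returns a uint8 array of shape (H // sf, W // sf, C).
    """
    rng = np.random.default_rng() if rng is None else rng
    blurred = box_blur(img, k=2 * sf - 1)            # low-pass first to limit aliasing
    h, w = (img.shape[0] // sf) * sf, (img.shape[1] // sf) * sf
    small = blurred[:h:sf, :w:sf]                    # keep every sf-th row and column
    noisy = small + rng.normal(0.0, noise_sigma, small.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)


img = np.full((64, 48, 3), 128, dtype=np.uint8)
low = simple_degrade(img, sf=2, rng=np.random.default_rng(0))
print(low.shape)  # (32, 24, 3)
```

The key property for N+L training is visible in the shape: with `sf=2` the output has half the height and width of the input, so any pixel-size reasoning about objects must account for the factor `sf`.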
YOLOv5's DataLoader and Dataset
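YOLOv5 creates its DataLoader and Dataset instances in /yolov5/utils/datasets.py, as shown below. The DataLoader is either PyTorch's `DataLoader` class used directly or a subclass built on it (`InfiniteDataLoader`); it is responsible for producing batch iterators during training and validation. One of its most important members is the Dataset, which loads the data and returns individual samples by index. Data augmentation is performed when the Dataset fetches a sample!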
```python
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        # Create the Dataset instance
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augmentation
                                      hyp=hyp,  # hyperparameters
                                      rect=rect,  # rectangular batches
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    # Use PyTorch's own DataLoader, or InfiniteDataLoader (a subclass of DataLoader)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=True,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
```
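The split between DataLoader and Dataset can be illustrated without PyTorch. The sketch below is a toy, pure-Python analogue (the names `ToyDataset` and `toy_loader` are hypothetical, not YOLOv5's): the dataset applies a random "augmentation" inside `__getitem__`, while the loader only shuffles indices and groups samples into batches. Because augmentation happens at fetch time, every epoch sees freshly augmented samples.

```python
import random
from typing import Iterator, List, Sequence


class ToyDataset:
    """Stores raw samples; augmentation happens only when a sample is fetched."""

    def __init__(self, samples: Sequence[int], augment: bool = True, p: float = 0.3):
        self.samples = list(samples)
        self.augment = augment
        self.p = p  # probability of applying the (toy) augmentation

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int):
        x = self.samples[index]                          # "load" the raw sample
        augmented = self.augment and random.random() < self.p
        if augmented:
            x = -x                                       # stand-in for a real image transform
        return x, augmented


def toy_loader(dataset, batch_size: int, shuffle: bool = True) -> Iterator[List]:
    """Mimic a DataLoader: shuffle indices, call dataset[i], yield batches."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


random.seed(0)
ds = ToyDataset(range(1, 11), p=0.3)
for batch in toy_loader(ds, batch_size=4):
    print(batch)
```

This is why, in YOLOv5, a custom augmentation belongs in the Dataset's fetch path rather than in the loader.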
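YOLOv5's Dataset is an instance of the `LoadImagesAndLabels` class. `LoadImagesAndLabels` loads the dataset when its constructor is called; the code in `__init__` that locates the image files and label files is shown below. The `cache_labels` method reads the label files and builds the label arrays. `cache_labels` calls a function that deserves special attention, `verify_image_label`, which verifies an image/label pair and in particular checks whether the image has labels (Attention: keep this in mind, we will need it later!). If an image has no labels, the label entry at that image's index is set to `np.zeros((0, 5), dtype=np.float32)`.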
```python
def __init__(self, ...):
    ...
    try:
        # Collect the image file paths
        f = []  # image files
        for p in path if isinstance(path, list) else [path]:
            p = Path(p)  # os-agnostic
            if p.is_dir():  # dir
                f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                # f = list(p.rglob('*.*'))  # pathlib
            elif p.is_file():  # file
                with open(p) as t:
                    t = t.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                    # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
            else:
                raise Exception(f'{prefix}{p} does not exist')
        self.img_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
        # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
        assert self.img_files, f'{prefix}No images found'
    except Exception as e:
        raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')

    # Check cache: if there is no valid label cache file, read the label files with cache_labels
    # Derive the label file paths from the image file paths
    self.label_files = img2label_paths(self.img_files)  # labels
    cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
    try:
        cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
        assert cache['version'] == self.cache_version  # same version
        assert cache['hash'] == get_hash(self.label_files + self.img_files)  # same hash
    except Exception:
        cache, exists = self.cache_labels(cache_path, prefix), False  # cache
```
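`img2label_paths` itself is not shown above. YOLOv5's dataset convention is that labels live in a `labels` directory parallel to `images`, with the same file stem and a `.txt` extension. The sketch below implements that mapping under this assumption; the name `image_to_label_path` is ours, not YOLOv5's.

```python
import os


def image_to_label_path(img_path: str) -> str:
    """Map .../images/xxx.jpg to .../labels/xxx.txt (replaces the last 'images' dir only)."""
    sa = os.sep + "images" + os.sep
    sb = os.sep + "labels" + os.sep
    head, sep, tail = img_path.rpartition(sa)
    swapped = head + sb + tail if sep else img_path   # leave the path as-is if there is no images/ dir
    return os.path.splitext(swapped)[0] + ".txt"


print(image_to_label_path(os.path.join(os.sep + "data", "images", "train", "0001.jpg")))
```

If a label path produced this way does not exist, `verify_image_label` treats the label as missing and falls back to an empty `(0, 5)` array.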
```python
def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                if f.read() != b'\xff\xd9':  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                    msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any([len(x) > 8 for x in lb]):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]  # segments is a list, so index it element by element
                    msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
            else:
                # The label file exists but is empty: use an empty (0, 5) label array
                ne = 1  # label empty
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            # The label file is missing: also use an empty (0, 5) label array
            nm = 1  # label missing
            lb = np.zeros((0, 5), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]
```
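Each line of a YOLO label file has the form `class x_center y_center width height`, with coordinates normalized to [0, 1]. The simplified sketch below (a hypothetical `parse_yolo_labels`, box labels only, no segments) mirrors the checks above and shows why an image without labels still yields a well-formed `(0, 5)` array: downstream code can concatenate and slice it like any other label array.

```python
import numpy as np


def parse_yolo_labels(text: str) -> np.ndarray:
    """Parse YOLO box labels into an (n, 5) float32 array; empty text gives shape (0, 5)."""
    rows = [line.split() for line in text.strip().splitlines() if line.strip()]
    if not rows:
        return np.zeros((0, 5), dtype=np.float32)
    lb = np.array(rows, dtype=np.float32)
    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
    assert (lb >= 0).all() and (lb[:, 1:] <= 1).all(), "labels must be non-negative and normalized"
    return np.unique(lb, axis=0)  # drop duplicate rows (note: this also sorts them)


empty = parse_yolo_labels("")
boxes = parse_yolo_labels("0 0.5 0.5 0.2 0.4\n2 0.1 0.1 0.05 0.05\n")
print(empty.shape, np.concatenate([empty, boxes], 0).shape)  # (0, 5) (2, 5)
```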
Data Augmentation in YOLOv5
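`LoadImagesAndLabels` performs data augmentation when a sample is fetched by index, in `__getitem__(self, index)`. YOLOv5's built-in augmentations include Mosaic, random rotation, random flipping, HSV jitter, and more. Mosaic is enabled by default during training, so we can go straight into `load_mosaic` and add our own custom augmentation there.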
```python
def __getitem__(self, index):
    index = self.indices[index]  # linear, shuffled, or image_weights

    hyp = self.hyp
    mosaic = self.mosaic and random.random() < hyp['mosaic']
    if mosaic:  # Mosaic is on by default, so going into load_mosaic is enough
        # Load mosaic
        img, labels = self.load_mosaic(index)
        shapes = None
        ...
```
```python
def load_mosaic(self, index):
    # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
    # Randomly pick 3 other images to stitch together with the given image
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    random.shuffle(indices)
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = self.load_image(index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    # After the mosaic, the label values change accordingly
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    # Apply further augmentation to the stitched mosaic image
    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
```
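The label bookkeeping in `load_mosaic` is easiest to check on a single tile. The sketch below re-implements, for the top-left tile only, the placement arithmetic from the code above, plus a NumPy version of the normalized-xywh to pixel-xyxy conversion that `xywhn2xyxy` is used for here (written from how it is called above; treat it as a sketch, not YOLOv5's source). `padw`/`padh` are the offsets that move a box from tile coordinates into mosaic-canvas coordinates, and they can be negative when the tile is cropped at the canvas edge.

```python
import numpy as np


def top_left_tile(xc: int, yc: int, w: int, h: int):
    """Return (canvas box, tile box, padw, padh) for the i == 0 (top-left) tile."""
    x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc          # region on the 2s x 2s canvas
    x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h          # matching region of the tile
    return (x1a, y1a, x2a, y2a), (x1b, y1b, x2b, y2b), x1a - x1b, y1a - y1b


def xywhn_to_xyxy(x: np.ndarray, w: int, h: int, padw: int = 0, padh: int = 0) -> np.ndarray:
    """Normalized (xc, yc, bw, bh) to pixel (x1, y1, x2, y2), shifted by (padw, padh)."""
    y = np.empty_like(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh
    return y


canvas_box, tile_box, padw, padh = top_left_tile(xc=250, yc=300, w=200, h=100)
print(canvas_box, tile_box, padw, padh)  # (50, 200, 250, 300) (0, 0, 200, 100) 50 200
labels = np.array([[0.5, 0.5, 0.25, 0.5]], dtype=np.float32)
print(xywhn_to_xyxy(labels, 200, 100, padw, padh).tolist())  # [[125.0, 225.0, 175.0, 275.0]]
```

Note that the conversion multiplies by the tile's `w` and `h`, so whatever size the loaded image actually has must be reflected in `w` and `h` for the labels to land in the right place.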
Adding Our Own Augmentation to YOLOv5
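Now, let's try adding BSRGAN's built-in degradation method, B-DEGRADE, to YOLOv5. The goal is to degrade images with a certain probability before a batch is fed into training: for each image we draw $P \sim U(0, 1)$ and degrade the image when $P < 0.3$. Then how does the DataLoader obtain a batch? From the code analysis above, through the Dataset instance's `__getitem__` method, of course. So we only need to degrade each image right after `__getitem__` reads it. Since `__getitem__` loads images by calling `self.load_mosaic`, we add our degradation method inside `load_mosaic`. So we have: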
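```python
# First, remember to import our augmentation method b_degrade
# (assumes the BSRGAN repository sits in the YOLOv5 root as an importable package named BSRGAN)
from BSRGAN.utils.utils_blindsr import degradation_bsrgan as b_degrade
...

def __getitem__(self, index):
    ...
    img, labels = self.load_mosaic(index)
    ...

def load_mosaic(self, index):
    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
    # Randomly pick 3 other images to stitch together with the given image
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    random.shuffle(indices)
    for i, index in enumerate(indices):
        # Load image (YOLOv5's load_image returns a uint8 BGR array; check which
        # input range/dtype and output size your copy of the BSRGAN function expects)
        img, _, (h, w) = self.load_image(index)

        ### Start: add our augmentation method b_degrade ###
        degrade_flag = False
        if self.augment:
            if random.random() < 0.3:
                img, _ = b_degrade(img, sf=2)  # set the degradation's downsampling factor to 2
                degrade_flag = True
        n, m, _ = img.shape
        h, w = n, m  # the mosaic placement below uses h, w, and a degraded image is smaller
        labels = self.labels[index].copy()  # load the labels early (earlier than in the stock code)
        # The degradation includes downsampling, so drop objects that can no longer be recognized
        if degrade_flag:
            labels = clear_cant_detect_object(labels, n, m, threshold=783)  # drop objects below the chosen threshold
        ### End: add our augmentation method b_degrade ###
        ...
```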
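`clear_cant_detect_object` is not shown in the original. One plausible reading, and it is only an assumption, is that `threshold` is a minimum box area in pixels of the degraded image. Because YOLO labels are normalized, downsampling leaves the label values unchanged but shrinks every box's pixel area by a factor of `sf**2`, so small objects may fall below what the detector can resolve. A sketch under that assumption (the function body is hypothetical; only the call signature comes from the code above):

```python
import numpy as np


def clear_cant_detect_object(labels: np.ndarray, n: int, m: int, threshold: float = 783) -> np.ndarray:
    """Hypothetical helper: drop boxes whose pixel area in an n x m image is below threshold.

    labels: (k, 5) array of (cls, x_center, y_center, width, height), normalized to [0, 1].
    n, m:   height and width of the (degraded) image in pixels.
    """
    if labels.size == 0:
        return labels                                   # keep the empty (0, 5) array as-is
    area = (labels[:, 3] * m) * (labels[:, 4] * n)      # box area in pixels
    return labels[area >= threshold]


lb = np.array([[0, 0.5, 0.5, 0.1, 0.1], [1, 0.2, 0.2, 0.05, 0.05]], dtype=np.float32)
print(clear_cant_detect_object(lb, 320, 320, threshold=783)[:, 0])  # only class 0 survives
```

In a 320 x 320 degraded image, the first box covers about 32 x 32 = 1024 pixels and is kept, while the second covers about 16 x 16 = 256 pixels and is dropped. If polygon segments are used, they would need to be filtered with the same mask; the original snippet does not show this.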
References
[1] Zhang, Kai, et al. "Designing a practical degradation model for deep blind image super-resolution." Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2021.