YOLOV5源码解读系列文章目录
前言
此篇为yolov5 3.1 版本,官方地址[https://github.com/ultralytics/yolov5]
看源代码之前有必要先大致了解实现原理和流程,强推这篇文章[https://blog.csdn.net/nan355655600/article/details/107852353](https://blog.csdn.net/nan355655600/article/details/107852353)
数据加载器由utils/datasets.py文件中的create_dataloader函数创建,其中主要有两个类构成LoadImagesAndLabels:数据集的加载和增强都由这个类实现 InfiniteDataLoader:对DataLoader进行封装,就是为了能够永久持续的采样数据,详细原因这里可以看官方说明[https://github.com/ultralytics/yolov5/pull/876](https://github.com/ultralytics/yolov5/pull/876)
持续采样InfiniteDataLoader
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader.

    A single iterator is created once and kept for the lifetime of the
    loader, so data can be sampled continuously across epochs without
    tearing down and re-spawning worker processes each epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader forbids reassigning 'batch_sampler' after construction,
        # so bypass its __setattr__ and install the endlessly-repeating
        # wrapper directly on the instance.
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        # One persistent iterator, reused by every __iter__ call below.
        self.iterator = super().__iter__()

    def __len__(self):
        # self.batch_sampler is the _RepeatSampler wrapper; its .sampler
        # attribute is the original batch sampler, so this is the number of
        # batches in one epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Serve exactly one epoch's worth of batches from the shared,
        # never-exhausted iterator.
        for _ in range(len(self)):
            yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever
永久持续的采样
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
数据加载
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
"""
path: 数据集路径
img_size: 图片大小
batch_size: 批次大小
augment: 是否数据增强
hyp: 超参数的yaml文件
rect: 矩形训练,就是对图片填充灰边(只在高或宽的一边填充)
image_weights: 图像采样的权重
cache_images: 图片是否缓存,用于加速训练
single_cls: 是否是一个类别
stride: 模型步幅, 图像大小/网络下采样之后的输出大小
pad: 填充宽度
rank: 当前进程编号
"""
self.img_size = img_size
self.augment = augment
self.hyp = hyp
self.image_weights = image_weights
self.rect = False if image_weights else rect
# mosaic 将4张图片融合在一张图片里,进行训练
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
"""
首先读取图像路径,转换合适的格式,根据图像路径,替换其中的images和图片后缀,转换成label路径
读取coco128/labels/train.cache文件,没有则创建,cache存储字典{图片路径:label路径,图片大小}
"""
def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    """
    Map each image path to its label path: replace the first
    '/images/' directory component with '/labels/' and swap the file
    extension for '.txt'.

    img_paths: iterable of image file paths (list or dict keys)
    returns: list of label file paths, one per input path
    """
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    # Rebuild from the splitext root instead of str.replace(ext, '.txt'):
    # replace() hits the FIRST occurrence of the extension substring anywhere
    # in the path (wrong for e.g. 'a.jpg.jpg'), and for an extension-less path
    # splitext returns '' and replace('', ...) corrupts the whole string.
    return [os.path.splitext(x.replace(sa, sb, 1))[0] + '.txt' for x in img_paths]
# 读取图像路径,转换成合适的格式
try:
f = [] # image files
for p in path if isinstance(path, list) else [path]:
p = str(Path(p)) # os-agnostic
parent = str(Path(p).parent) + os.sep #上级目录 ../coco128/images
if os.path.isfile(p): # file
with open(p, 'r') as t:
t = t.read().splitlines()
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
elif os.path.isdir(p): # folder
f += glob.iglob(p + os.sep + '*.*') # 读取images下的所有文件不包含目录
else:
raise Exception('%s does not exist' % p)
# 将图片的路径改为适合本地系统的格式(windows是'\\', linux是'/'),图片后缀名在img_formats里的就改为小写
self.img_files = sorted(
[x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
assert len(self.img_files) > 0, 'No images found'
except Exception as e:
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
# Check cache
self.label_files = img2label_paths(self.img_files) # labels 图片路径到label路径的转换
cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels
"""
读取labels下的.cache文件, 没有则创建, cache里的关键字'hash'是图片+label的文件字节大小之和
"""
if os.path.isfile(cache_path):
cache = torch.load(cache_path) # load
# 如果cache存储的hash与当前的label+图片大小对应不上,则重新创建.cache文件
if cache['hash'] != get_hash(self.label_files + self.img_files): # dataset changed
cache = self.cache_labels(cache_path) # re-cache
else:
cache = self.cache_labels(cache_path) # cache
# Read cache
cache.pop('hash') # remove hash
labels, shapes = zip(*cache.values())
self.labels = list(labels) # label
self.shapes = np.array(shapes, dtype=np.float64) # 图片大小
self.img_files = list(cache.keys()) # update 图片路径
self.label_files = img2label_paths(cache.keys()) # update 更新labels路径,因为可能有一部分图片或label损坏
"""
根据图片数量划分每批的图片数量
"""
n = len(shapes) # number of images 图片数量
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index 划分批次
nb = bi[-1] + 1 # number of batches 批次数量
self.batch = bi # batch index of image
self.n = n
# Rectangular Training 矩形训练
"""
先求的图像的宽高比,然后对较长的边缩放到stride的倍数,
在按照宽高比对短的一边缩放,进行少量的填充也达到stride的最小倍数
"""
if self.rect:
# Sort by aspect ratio
s = self.shapes # wh
ar = s[:, 1] / s[:, 0] # aspect ratio 高宽比
irect = ar.argsort() # 按着高宽比从小到大排序
# 重新排序图片,label路径,真实框, shapes, 宽高比的顺序
self.img_files = [self.img_files[i] for i in irect]
self.label_files = [self.label_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
self.shapes = s[irect] # wh
ar = ar[irect]
# Set training image shapes
shapes = [[1, 1]] * nb # [[h/w, 1], [1, w/h]....]