😎😎😎物体检测-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
5、LoadImagesAndLabels类的__init__函数
5.1 定义、参数
class LoadImagesAndLabels(Dataset):
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
self.image_weights = image_weights
self.rect = False if image_weights else rect
self.mosaic = self.augment and not self.rect
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
self.path = path
- 继承自PyTorch的Dataset类
- 构造函数
- img_size ,输入图像的长、宽
- augment ,加载图像是否使用图像增强
- hyp ,超参数字典,包含数据增强、学习率等参数
- image_weights ,是否根据图像权重采样
- rect ,是否使用矩形训练
- mosaic ,当启用数据增强且不使用矩形训练时,此值为True。Mosaic数据增强会一次性加载4张图像,将它们组合成一个大的马赛克图像,训练时有助于模型学习到不同尺度的检测对象(马赛克4张拼成一张)
- mosaic_border ,定义在创建mosaic图像时使用的边界,通常取决于目标图像大小
- stride , 从输入到输出的降采样的比例
- path ,数据集路径
5.2 读取所有图像路径
try:
f = []
for p in path if isinstance(path, list) else [path]:
p = Path(p)
if p.is_dir():
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
elif p.is_file():
with open(p, 'r') as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
f += [x.replace('./', parent) if x.startswith('./') else x for x in t]
else:
raise Exception(f'{prefix}{p} does not exist')
self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
assert self.img_files, f'{prefix}No images found'
except Exception as e:
raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
- try
- f,用于存储所有图像的路径字符串
- 将数据集路径文件夹中的所有图像for循环遍历读取,如果path是一个list则遍历成list,如果不是将其转换成列表后遍历
- p,使用Path类将路径字符串转换成一个路径对象,Path是pathlib工具包的一个模块
- 如果p是一个路径
- 则使用glob模块递归地查找该目录及其所有子目录中的所有文件。
**
表示匹配所有目录,*.*
表示匹配所有文件 - 如果p是一个文件
- 则打开并读取它
- t,将内容按行分割,保存为一个list
- parent ,获取当前文件的父目录路径,str()函数将其转换为字符串。os.sep是当前操作系统的路径分隔符(在Windows上是\,在Unix/Linux上是/)
- f,按照windows格式或者linux格式,替换成绝对路径,并添加到f中
- 如果不是文件也不是路径
- 则抛出异常
- img_files :
- for循环遍历f的元素x
- 将x字符串按照.分割,然后取出.后面的元素即文件后缀名
- 如果这个后缀名在img_formats 这个list中的一个即是一个图像数据,则将/替换为os.sep
- 将所有的x封装到一个list中,然后将list进行排序,排序的结果就是img_files
- 其中img_formats = [‘bmp’, ‘jpg’, ‘jpeg’, ‘png’, ‘tif’, ‘tiff’, ‘dng’, ‘webp’, ‘mpo’]
- 确保文件列表不为空,如果为空抛出异常
- 如果在执行上述操作的过程中发生任何异常
- 则捕获这个异常并抛出一个新的异常
5.3 缓存
self.label_files = img2label_paths(self.img_files)
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
if cache_path.is_file():
cache, exists = torch.load(cache_path), True
if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:
cache, exists = self.cache_labels(cache_path, prefix), False # re-cache
else:
cache, exists = self.cache_labels(cache_path, prefix), False
nf, nm, ne, nc, n = cache.pop('results')
if exists:
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=prefix + d, total=n, initial=n)
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
cache.pop('hash')
cache.pop('version')
labels, shapes, self.segments = zip(*cache.values())
self.labels = list(labels)
self.shapes = np.array(shapes, dtype=np.float64)
self.img_files = list(cache.keys())
self.label_files = img2label_paths(cache.keys())
if single_cls:
for x in self.labels:
x[:, 0] = 0
n = len(shapes)
bi = np.floor(np.arange(n) / batch_size).astype(np.int)
nb = bi[-1] + 1
self.batch = bi
self.n = n
self.indices = range(n)
检查缓存:
- label_files ,通过图像文件路径使用 img2label_paths函数 获取相应的标签文件路径
- cache_path ,缓存文件的路径
- 如果path是一个文件,则直接用这个文件的路径(但后缀改为.cache)作为缓存路径
- 否则,取第一个标签文件的父目录,并创建或使用同目录下的.cache文件
- 如果缓存文件存在:
- cache, exists,加载缓存,一个存在的标记(布尔值)
展示缓存:
- 如果缓存的hash值与当前图像和标签文件列表的hash值不匹配,或者缓存版本信息缺失
- cache, exists,则使用cache_labels函数重新创建缓存,一个不存在的标记
- 如果缓存文件不存在:
- cache, exists,则使用cache_labels函数重新创建缓存,一个不存在的标记
- nf, nm, ne, nc, n,从缓存中取出图像标签文件的状态统计,包括找到的、缺失的、空的、损坏的标签文件数量以及总数
- 如果缓存已存在
- d,创建一个格式化的字符串d,用于描述缓存文件的内容
- 使用tqdm显示缓存扫描结果,tqdm是一个进度条工具,这有助于在加载大型数据集时提供反馈
- 确保至少找到了一个标签文件,或者没有启用数据增强,否则无法进行训练
读取缓存:
- 从缓存中移除hash信息
- 从缓存中移除version信息
- labels, shapes, segments,解包缓存中的值,分别为标签、图像shape值、分段信息
- labels,更新类实例的labels属性
- shapes,更新类实例的shapes属性
- img_files,更新类实例的img_files属性
- label_files,更新类实例的label_files属性
- 如果启用了单类模式(即检测的物体中,只有一个类别)
- 循环遍历所有标签
- 类别全部设置为0
- n,图像数量
- bi ,每个图像的批次索引
- nb ,总批次数量
- batch ,更新批次索引
- n,更新图像数量
- indices ,用于存储数据集中所有图像的索引
5.4 矩形训练
- 实现矩形训练以优化图像加载和处理
- 可选地将图像缓存到内存中以加速训练
if self.rect:
s = self.shapes
ar = s[:, 1] / s[:, 0]
irect = ar.argsort()
self.img_files = [self.img_files[i] for i in irect]
self.label_files = [self.label_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
self.shapes = s[irect]
ar = ar[irect]
shapes = [[1, 1]] * nb
for i in range(nb):
ari = ar[bi == i]
mini, maxi = ari.min(), ari.max()
if maxi < 1:
shapes[i] = [maxi, 1]
elif mini > 1:
shapes[i] = [1, 1 / mini]
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
- 检查是否启用矩形训练,如果self.rect为True,则执行矩形训练的代码块
- s,存储数据集中每张图像的宽度和高度
- ar,计算每张图像的宽高比
- irect ,对宽高比进行升序排序
- img_files ,根据宽高比索引,重新排序图像文件路径列表,现在的img_files 的图像文件路径字符串全部按照宽高比从小到大进行排序
- label_files ,同样重新排序标签文件路径列表
- labels ,重新排序标签列表
- shapes ,重新排序图像形状数组
- ar,重新排序宽高比数组
- shapes ,初始化一个列表shapes,其中每个元素都是[1, 1],列表长度为批次数量nb
- 遍历每个批次,i是批次索引
- ari,从排序后的宽高比数组ar中选择当前批次i对应的图像的宽高比
- mini, maxi,计算当前批次图像宽高比的最小值和最大值
- 如果当前批次中最大宽高比小于1,表示所有图像都比较窄,则
- 设置当前批次的形状为[maxi, 1],以适应窄图像
- 如果当前批次中最小宽高比大于1,表示所有图像都比较宽
- 设置当前批次的形状为[1, 1 / mini],以适应宽图像
- batch_shapes ,计算每个批次的图像形状,并确保它们符合网络的步长要求。这里使用了向上取整、缩放、填充和类型转换,以生成最终的批次形状self.batch_shapes
5.5 缓存图像到内存
加载图像缓存到内存中,但是过大的数据集可能会超过系统内存
self.imgs = [None] * n
if cache_images:
gb = 0
self.img_hw0, self.img_hw = [None] * n, [None] * n
results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
pbar = tqdm(enumerate(results), total=n)
for i, x in pbar:
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
gb += self.imgs[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
pbar.close()
- imgs ,初始化图像缓存列表self.imgs,长度为n(图像总数),初始值都为None
- 如果启用了图像缓存,则执行以下缓存逻辑:
- gb,初始化一个变量,用于跟踪已缓存图像的总大小
- img_hw0, img_hw,初始化两个列表长度都为n(图像总数),用于存储每张图像的原始尺寸和调整后的尺寸。初始值都为None
- results ,使用线程池ThreadPool并行加载图像,以提高加载速度。这里创建了一个包含8个线程的线程池,并对每张图像调用
load_image
函数进行加载。imap
函数用于应用函数到输入序列的每个元素,这里的输入序列是通过zip和range(n)生成的,包含了每张图像的索引 - pbar ,使用tqdm创建一个进度条,以可视化图像加载和缓存的进度。enumerate(results)提供了一个带索引的结果序列,total=n指定了总进度条长度为图像总数n
- 遍历加载结果,i是图像的索引,x是加载结果(包括图像数据及其尺寸信息)
- 将加载的图像数据x分配给self.imgs[i],原始尺寸赋值给self.img_hw0[i],调整后的尺寸赋值给self.img_hw[i],这样,每张图像及其尺寸信息都被缓存起来
- gb ,将当前图像数据的字节大小(
self.imgs[i].nbytes
)加到gb变量上,以更新缓存的总大小 - 更新进度条的描述,显示当前已缓存图像的总大小(以GB为单位)。这提供了实时反馈,训练时知道缓存进度和内存使用情况
- 完成所有图像的加载和缓存后,关闭进度条。这是确保tqdm进度条正确结束并清理资源的标准做法
load_image
函数:
def load_image(self, index):
# loads 1 image from dataset, returns img, original hw, resized hw
img = self.imgs[index]
if img is None: # not cached
path = self.img_files[index]
img = cv2.imread(path) # BGR
assert img is not None, 'Image Not Found ' + path
h0, w0 = img.shape[:2] # orig hw
r = self.img_size / max(h0, w0) # resize image to img_size
if r != 1: # always resize down, only resize up if training with augmentation
interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
else:
return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized