物体检测-系列教程13:YOLOV5 源码解析3 (LoadImagesAndLabels类:构造函数)

😎😎😎物体检测-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

5、LoadImagesAndLabels类的__init__函数

5.1 定义、参数

class LoadImagesAndLabels(Dataset):
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
  1. 继承自PyTorch的Dataset类
  2. 构造函数
  3. img_size ,输入图像的长、宽
  4. augment ,加载图像是否使用图像增强
  5. hyp ,超参数字典,包含数据增强、学习率等参数
  6. image_weights ,是否根据图像权重采样
  7. rect ,是否使用矩形训练
  8. mosaic ,当启用数据增强且不使用矩形训练时,此值为True。Mosaic数据增强会一次性加载4张图像,将它们组合成一个大的马赛克图像,训练时有助于模型学习到不同尺度的检测对象(马赛克4张拼成一张)
  9. mosaic_border ,定义在创建mosaic图像时使用的边界,通常取决于目标图像大小
  10. stride , 从输入到输出的降采样的比例
  11. path ,数据集路径

5.2 读取所有图像路径

        try:
            f = []
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)
                if p.is_dir():
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                elif p.is_file():
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
  1. try
  2. f,用于存储所有图像的路径字符串
  3. 将数据集路径文件夹中的所有图像for循环遍历读取,如果path是一个list则遍历成list,如果不是将其转换成列表后遍历
  4. p,使用Path类将路径字符串转换成一个路径对象,Path是pathlib工具包的一个模块
  5. 如果p是一个路径
  6. 则使用glob模块递归地查找该目录及其所有子目录中的所有文件。**表示匹配所有目录,*.*表示匹配所有文件
  7. 如果p是一个文件
  8. 则打开并读取它
  9. t,将内容按行分割,保存为一个list
  10. parent ,获取当前文件的父目录路径,str()函数将其转换为字符串。os.sep是当前操作系统的路径分隔符(在Windows上是\,在Unix/Linux上是/)
  11. f,按照windows格式或者linux格式,替换成绝对路径,并添加到f中
  12. 如果不是文件也不是路径
  13. 则抛出异常
  14. img_files :
    • for循环遍历f的元素x
    • 将x字符串按照.分割,然后取出.后面的元素即文件后缀名
    • 如果这个后缀名在img_formats 这个list中的一个即是一个图像数据,则将/替换为os.sep
    • 将所有的x封装到一个list中,然后将list进行排序,排序的结果就是img_files
    • 其中img_formats = [‘bmp’, ‘jpg’, ‘jpeg’, ‘png’, ‘tif’, ‘tiff’, ‘dng’, ‘webp’, ‘mpo’]
  15. 确保文件列表不为空,如果为空抛出异常
  16. 如果在执行上述操作的过程中发生任何异常
  17. 则捕获这个异常并抛出一个新的异常

5.3 缓存

        self.label_files = img2label_paths(self.img_files)
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True
            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False
        nf, nm, ne, nc, n = cache.pop('results')
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
        cache.pop('hash')
        cache.pop('version')
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())
        self.label_files = img2label_paths(cache.keys())
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0
        n = len(shapes)
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)
        nb = bi[-1] + 1
        self.batch = bi
        self.n = n
        self.indices = range(n)

检查缓存:

  1. label_files ,通过图像文件路径使用 img2label_paths函数 获取相应的标签文件路径
  2. cache_path ,缓存文件的路径
    • 如果path是一个文件,则直接用这个文件的路径(但后缀改为.cache)作为缓存路径
    • 否则,取第一个标签文件的父目录,并创建或使用同目录下的.cache文件
  3. 如果缓存文件存在:
  4. cache, exists,加载缓存,一个存在的标记(布尔值)

展示缓存:

  1. 如果缓存的hash值与当前图像和标签文件列表的hash值不匹配,或者缓存版本信息缺失
  2. cache, exists,则使用cache_labels函数重新创建缓存,一个不存在的标记
  3. 如果缓存文件不存在:
  4. cache, exists,则使用cache_labels函数重新创建缓存,一个不存在的标记
  5. nf, nm, ne, nc, n,从缓存中取出图像标签文件的状态统计,包括找到的、缺失的、空的、损坏的标签文件数量以及总数
  6. 如果缓存已存在
  7. d,创建一个格式化的字符串d,用于描述缓存文件的内容
  8. 使用tqdm显示缓存扫描结果,tqdm是一个进度条工具,这有助于在加载大型数据集时提供反馈
  9. 确保至少找到了一个标签文件,或者没有启用数据增强,否则无法进行训练

读取缓存:

  1. 从缓存中移除hash信息
  2. 从缓存中移除version信息
  3. labels, shapes, segments,解包缓存中的值,分别为标签、图像shape值、分段信息
  4. labels,更新类实例的labels属性
  5. shapes,更新类实例的shapes属性
  6. img_files,更新类实例的img_files属性
  7. label_files,更新类实例的label_files属性
  8. 如果启用了单类模式(即检测的物体中,只有一个类别)
  9. 循环遍历所有标签
  10. 类别全部设置为0
  11. n,图像数量
  12. bi ,每个图像的批次索引
  13. nb ,总批次数量
  14. batch ,更新批次索引
  15. n,更新图像数量
  16. indices ,用于存储数据集中所有图像的索引

5.4 矩形训练

  • 实现矩形训练以优化图像加载和处理
  • 可选地将图像缓存到内存中以加速训练
        if self.rect:
            s = self.shapes
            ar = s[:, 1] / s[:, 0]
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]
            ar = ar[irect]
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]
            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride  
  1. 检查是否启用矩形训练,如果self.rect为True,则执行矩形训练的代码块
  2. s,存储数据集中每张图像的宽度和高度
  3. ar,计算每张图像的宽高比
  4. irect ,对宽高比进行升序排序
  5. img_files ,根据宽高比索引,重新排序图像文件路径列表,现在的img_files 的图像文件路径字符串全部按照宽高比从小到大进行排序
  6. label_files ,同样重新排序标签文件路径列表
  7. labels ,重新排序标签列表
  8. shapes ,重新排序图像形状数组
  9. ar,重新排序宽高比数组
  10. shapes ,初始化一个列表shapes,其中每个元素都是[1, 1],列表长度为批次数量nb
  11. 遍历每个批次,i是批次索引
  12. ari,从排序后的宽高比数组ar中选择当前批次i对应的图像的宽高比
  13. mini, maxi,计算当前批次图像宽高比的最小值和最大值
  14. 如果当前批次中最大宽高比小于1,表示所有图像都比较窄,则
  15. 设置当前批次的形状为[maxi, 1],以适应窄图像
  16. 如果当前批次中最小宽高比大于1,表示所有图像都比较宽
  17. 设置当前批次的形状为[1, 1 / mini],以适应宽图像
  18. batch_shapes ,计算每个批次的图像形状,并确保它们符合网络的步长要求。这里使用了向上取整、缩放、填充和类型转换,以生成最终的批次形状self.batch_shapes

5.5 缓存图像到内存

加载图像缓存到内存中,但是过大的数据集可能会超过系统内存

        self.imgs = [None] * n
        if cache_images:
            gb = 0
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()  
  1. imgs ,初始化图像缓存列表self.imgs,长度为n(图像总数),初始值都为None
  2. 如果启用了图像缓存,则执行以下缓存逻辑:
  3. gb,初始化一个变量,用于跟踪已缓存图像的总大小
  4. img_hw0, img_hw,初始化两个列表长度都为n(图像总数),用于存储每张图像的原始尺寸和调整后的尺寸。初始值都为None
  5. results ,使用线程池ThreadPool并行加载图像,以提高加载速度。这里创建了一个包含8个线程的线程池,并对每张图像调用load_image函数进行加载。imap函数用于应用函数到输入序列的每个元素,这里的输入序列是通过zip和range(n)生成的,包含了每张图像的索引
  6. pbar ,使用tqdm创建一个进度条,以可视化图像加载和缓存的进度。enumerate(results)提供了一个带索引的结果序列,total=n指定了总进度条长度为图像总数n
  7. 遍历加载结果,i是图像的索引,x是加载结果(包括图像数据及其尺寸信息)
  8. 将加载的图像数据x分配给self.imgs[i],原始尺寸赋值给self.img_hw0[i],调整后的尺寸赋值给self.img_hw[i],这样,每张图像及其尺寸信息都被缓存起来
  9. gb ,将当前图像数据的字节大小(self.imgs[i].nbytes)加到gb变量上,以更新缓存的总大小
  10. 更新进度条的描述,显示当前已缓存图像的总大小(以GB为单位)。这提供了实时反馈,训练时知道缓存进度和内存使用情况
  11. 完成所有图像的加载和缓存后,关闭进度条。这是确保tqdm进度条正确结束并清理资源的标准做法

load_image函数:

def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized
  • 21
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

机器学习杨卓越

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值