YOLOV5源码解读(数据集加载和增强)

YOLOV5源码解读系列文章目录

  1. 数据集加载和增强
  2. loss计算

前言

此篇为yolov5 3.1 版本,官方地址[https://github.com/ultralytics/yolov5]
看源代码之前有必要先大致了解实现原理和流程,强推这篇文章https://blog.csdn.net/nan355655600/article/details/107852353(https://github.com/amdegroot/ssd.pytorch)


数据加载器由utils/datasets.py文件中的create_dataloader函数创建,其中主要有两个类构成LoadImagesAndLabels:数据集的加载和增强都由这个类实现 InfiniteDataLoader:对DataLoader进行封装,就是为了能够永久持续的采样数据,详细原因这里可以看官方说明[https://github.com/ultralytics/yolov5/pull/876](https://github.com/ultralytics/yolov5/pull/876)

持续采样InfiniteDataLoader

class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    """
    这块对DataLoader进行封装,就是为了能够永久持续的采样数据
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever
    永久持续的采样
    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

数据加载

class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
        """
        path:    数据集路径
        img_size:    图片大小
        batch_size:  批次大小
        augment: 是否数据增强
        hyp: 超参数的yaml文件
        rect: 矩形训练,就是对图片填充灰边(只在高或宽的一边填充)
        image_weights: 图像采样的权重
        cache_images: 图片是否缓存,用于加速训练
        single_cls: 是否是一个类别
        stride: 模型步幅, 图像大小/网络下采样之后的输出大小
        pad: 填充宽度
        rank: 当前进程编号
        """
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        # mosaic 将4张图片融合在一张图片里,进行训练
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride

        """
        首先读取图像路径,转换合适的格式,根据图像路径,替换其中的images和图片后缀,转换成label路径
        读取coco128/labels/train.cache文件,没有则创建,cache存储字典{图片路径:label路径,图片大小}
        """

        def img2label_paths(img_paths):
            # Define label paths as a function of image paths
            """
            img_paths现在存储了所有的图片路径,只需将路径中的images换成labels,图片后缀改为.txt就得到标注文件的路径
            """
            sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
            return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]
        # 读取图像路径,转换成合适的格式
        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = str(Path(p))  # os-agnostic
                parent = str(Path(p).parent) + os.sep   #上级目录  ../coco128/images
                if os.path.isfile(p):  # file
                    with open(p, 'r') as t:
                        t = t.read().splitlines()
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                elif os.path.isdir(p):  # folder
                    f += glob.iglob(p + os.sep + '*.*')     # 读取images下的所有文件不包含目录
                else:
                    raise Exception('%s does not exist' % p)
            # 将图片的路径改为适合本地系统的格式(windows是'\\', linux是'/'),图片后缀名在img_formats里的就改为小写
            self.img_files = sorted(
                [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
            assert len(self.img_files) > 0, 'No images found'
        except Exception as e:
            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels 图片路径到label路径的转换
        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
        """
        读取labels下的.cache文件, 没有则创建, cache里的关键字'hash'是图片+label的文件字节大小之和
        """
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)  # load
            # 如果cache存储的hash与当前的label+图片大小对应不上,则重新创建.cache文件
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
                cache = self.cache_labels(cache_path)  # re-cache
        else:
            cache = self.cache_labels(cache_path)  # cache

        # Read cache
        cache.pop('hash')  # remove hash
        labels, shapes = zip(*cache.values())
        self.labels = list(labels)      # label
        self.shapes = np.array(shapes, dtype=np.float64)    # 图片大小
        self.img_files = list(cache.keys())  # update   图片路径
        self.label_files = img2label_paths(cache.keys())  # update  更新labels路径,因为可能有一部分图片或label损坏

        """
        根据图片数量划分每批的图片数量
        """
        n = len(shapes)  # number of images     图片数量
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index  划分批次
        nb = bi[-1] + 1  # number of batches    批次数量
        self.batch = bi  # batch index of image
        self.n = n

        # Rectangular Training  矩形训练
        """
        先求的图像的宽高比,然后对较长的边缩放到stride的倍数,
        在按照宽高比对短的一边缩放,进行少量的填充也达到stride的最小倍数
        """
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio  高宽比
            irect = ar.argsort()    # 按着高宽比从小到大排序
            # 重新排序图片,label路径,真实框, shapes, 宽高比的顺序
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb  # [[h/w, 1], [1, w/h]....]
  
yolov5源码是一个用于目标检测的项目,以下是对源码的一些解读: 1. 项目目录结构:源码包含了data、models、utils、train、test等文件夹。其中,data文件夹包含了用于配置数据集的yaml文件和下载数据集的shell命令;models文件夹包含了模型的定义和相关操作;utils文件夹包含了一些辅助函数和工具;train文件夹包含了训练相关的代码;test文件夹包含了测试相关的代码。\[1\] 2. 数据集配置文件:在data文件夹中,可以找到yaml文件,用于配置不同的数据集,如coco、coco128、pascalvoc等。这些配置文件定义了数据集的路径、类别信息、图像大小等。\[1\] 3. 超参数微调配置文件:在data文件夹中,还有一个hyps文件夹,其中的yaml文件用于微调超参数,以优化模型的性能。\[1\] 4. 脚本文件:在scripts文件夹中,存放着下载数据集和权重的shell脚本,可以通过运行这些脚本来获取所需的数据集和权重文件。\[1\] 5. 项目解读:对于项目的解读,可以从项目目录结构开始,了解每个文件的作用和功能。可以先从最基础的文件开始,逐步深入理解代码。同时,可以参考作者提供的英文文档进行解读,也可以参考其他相关资料和教程。\[2\] 总之,yolov5源码是一个用于目标检测的项目,包含了数据集配置、模型定义、训练和测试等功能。通过对源码解读,可以深入理解该项目的实现原理和使用方法。\[1\]\[2\] #### 引用[.reference_title] - *1* [YOLOV5源码的详细解读](https://blog.csdn.net/BGMcat/article/details/120930016)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析](https://blog.csdn.net/weixin_43334693/article/details/129356033)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值