YOLOV5代码精读之数据处理(dataloader.py)

pytorch数据集处理icon-default.png?t=O83Ahttps://blog.csdn.net/a8039974/article/details/142015862?spm=1001.2014.3001.5502

一、【检测】创建数据加载器(dataloader)

def create_dataloader(path,
                      imgsz,
                      batch_size,
                      stride,
                      single_cls=False,
                      hyp=None,
                      augment=False,
                      cache=False,
                      pad=0.0,
                      rect=False,
                      rank=-1,
                      workers=8,
                      image_weights=False,
                      quad=False,
                      prefix='',
                      shuffle=False):
    if rect and shuffle:
        LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=PIN_MEMORY,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
                  worker_init_fn=seed_worker,
                  generator=generator), dataset

这段代码定义了一个函数 create_dataloader,用于创建一个数据加载器(DataLoader),通常用于深度学习模型的训练和验证过程。下面是对代码的逐步分解和详细解释:

参数列表

  • path: 数据集的路径,可以是图像或标签文件的路径。
  • imgsz: 输入图像的大小,通常是一个整数值,例如640,表示宽度和高度为640的图像。
  • batch_size: 每个批次的样本数量,用于训练时设置。
  • stride: 在图像中滑动的步长,通常与模型架构相关,模型最大是32,【32,16,8】。
  • single_cls: 布尔值,若为真,则表示数据集是否为单类别
  • hyp: 超参数,用于训练过程的调优。
  • augment: 布尔值,若为真,则在数据加载时进行数据增强
  • cache: 布尔值,是否缓存图像数据以加快加载速度。
  • pad: 填充的比例,通常用于处理不同尺寸的图像,设置矩形训练时的填充,默认为0。
  • rect: 布尔值,指明是否采用矩形训练
  • rank: 用于分布式训练的排名标识,多卡训练时的进程编号,rank=-1时GPU=1不进行分布式训练,rank=-1多块GPU使用DataParallel模式,默认是-1。
  • workers: 数据加载的线程数,加载CPU的线程数。
  • image_weights: 布尔值,指示是否根据图像的权重进行加载。
  • quad: 布尔值,控制是否进行四图像模式的加载。
  • prefix: 用于日志的字符串前缀。
  • shuffle: 布尔值,指示是否在每个epoch中打乱数据顺序

冲突警告

if rect and shuffle:
    LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
    shuffle = False

这段代码检查如果同时设置了 rect 和 shuffle,则会发出警告并将 shuffle 设置为 False,因为矩形批处理与打乱数据不兼容。

数据集初始化

with torch_distributed_zero_first(rank):
    dataset = LoadImagesAndLabels(
        path,
        imgsz,
        batch_size,
        augment=augment,
        hyp=hyp,
        rect=rect,
        cache_images=cache,
        single_cls=single_cls,
        stride=int(stride),
        pad=pad,
        image_weights=image_weights,
        prefix=prefix)

使用 LoadImagesAndLabels 类加载数据集,该类负责图像和标签的加载。使用 torch_distributed_zero_first(rank) 允许分布式训练中仅初始化一次数据集。

设置批次大小

batch_size = min(batch_size, len(dataset))

设置批次大小为实际数据集中样本数量的最小值,以避免超出范围的情况。

确定可用的工作线程数

nd = torch.cuda.device_count()  # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])

计算可用的 CUDA 设备数量和 CPU 核心数,并确定使用的工作线程数。

设置采样器

sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)

如果是在分布式训练中,则使用 DistributedSampler 来分配数据。

选择数据加载器

loader = DataLoader if image_weights else InfiniteDataLoader

根据是否使用图像权重选择数据加载器,如果需要更新属性则使用标准的 DataLoader,否则使用无限数据加载器 InfiniteDataLoader

创建随机数生成器

generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)

初始化一个随机数生成器,并设置随机种子,以确保每次运行的随机结果一致。

返回数据加载器和数据集

return loader(dataset,
              batch_size=batch_size,
              shuffle=shuffle and sampler is None,
              num_workers=nw,
              sampler=sampler,
              pin_memory=PIN_MEMORY,
              collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
              worker_init_fn=seed_worker,
              generator=generator), dataset

最后,返回创建的数据加载器和数据集对象。数据加载器被配置为使用可能的数据分布、批次大小、是否打乱等选项。

create_dataloader 函数主要功能是创建一个适用于深度学习训练和验证的数据加载器,支持多种配置选项,包括数据增强、缓存、是否为单类等。通过灵活地使用多线程和分布式采样,它提升了数据加载的效率,确保了训练过程中数据的有效利用。该函数封装了数据准备的复杂性,使得后续模型训练能够更集中于核心逻辑。

二、【检测】LoadImagesAndLabels

class LoadImagesAndLabels(Dataset):
    # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
    cache_version = 0.6  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 path,
                 img_size=640,
                 batch_size=16,
                 augment=False,
                 hyp=None,
                 rect=False,
                 image_weights=False,
                 cache_images=False,
                 single_cls=False,
                 stride=32,
                 pad=0.0,
                 min_items=0,
                 prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        self.albumentations = Albumentations(size=img_size) if augment else None

        print("dataloader",img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls, stride, pad, min_items, prefix)
        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                print("path",p)
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t]  # to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # to global path (pathlib)
                else:
                    raise FileNotFoundError(f'{prefix}{p} does not exist')
            self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert self.im_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e

        # Check cache
        self.label_files = img2label_paths(self.im_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except Exception:
            cache, exists = self.cache_labels(cache_path, prefix), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels, shapes, self.segments = zip(*cache.values())
        nl = len(np.concatenate(labels, 0))  # number of labels
        assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
        self.labels = list(labels)
        self.shapes = np.array(shapes)
        self.im_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update

        # Filter images
        if min_items:
            include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
            LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
            self.im_files = [self.im_files[i] for i in include]
            self.label_files = [self.label_files[i] for i in include]
            self.labels = [self.labels[i] for i in include]
            self.segments = [self.segments[i] for i in include]
            self.shapes = self.shapes[include]  # wh

        # Create indices
        n = len(self.shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Update labels
        include_class = []  # filter labels to include only these classes (optional)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
                    self.segments[i] = segment[j]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0
                if segment:
                    self.segments[i][:, 0] = 0

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.im_files = [self.im_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.segments = [self.segments[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into RAM/disk for faster training
        if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
            cache_images = False
        self.ims = [None] * n
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache_images:
            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
            self.im_hw0, self.im_hw = [None] * n, [None] * n
            fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
            results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
            pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache_images == 'disk':
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
            pbar.close()

    def check_cache_ram(self, safety_margin=0.1, prefix=''):
        # Check image caching requirements vs available memory
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        n = min(self.n, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
            b += im.nbytes * ratio ** 2
        mem_required = b * self.n / n  # GB required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
                        f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
        return cache

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning {path.parent / path.stem}..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                        desc=desc,
                        total=len(self.im_files),
                        bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"

        pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x

    def __len__(self):
        return len(self.im_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(img,
                                                 labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

            # Cutouts
            # labels = cutout(img, labels, p=0.5)
            # nl = len(labels)  # update after cutout

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.im_files[index], shapes

    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                assert im is not None, f'Image Not Found {f}'
            h0, w0 = im.shape[:2]  # orig hw
            r = self.img_size / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images_to_disk(self, i):
        # Saves an image as an *.npy file for faster loading
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
        img4, labels4 = random_perspective(img4,
                                           labels4,
                                           segments4,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img4, labels4

    def load_mosaic9(self, index):
        # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
        img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

        # Concat/clip labels
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
        img9, labels9 = random_perspective(img9,
                                           labels9,
                                           segments9,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img9, labels9

    @staticmethod
    def collate_fn(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
                                    align_corners=False)[0].type(im[i].type())
                lb = label[i]
            else:
                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im1)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()

        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4

2.1 类定义:LoadImagesAndLabels

  • class LoadImagesAndLabels(Dataset):

    这是一个定义在 YOLOv5 中用于加载图像和标签的类,继承自 Dataset

LoadImagesAndLabels 類的主要功能是从给定的路径中加载图像及其相应的标签并进行必要的预处理和增强。它为 YOLOv5 模型的训练提供了高效的数据加载解决方案,包括:

  • 支持多种数据增强方法,提高模型的泛化能力。
  • 通过缓存机制和形状调整优化训练数据加载速度。
  • 能够处理多类和单类训练任务,并允许自定义数据预处理。

这个类在目标检测任务中非常关键,能够确保输入的数据质量,提升模型的训练效率。

  类属性

  • cache_version = 0.6

    版本控制,用于标识数据集标签缓存的版本。

  • rand_interp_methods:

    该列表包含用于图像插值的方法,使用 OpenCV 的标志来进行图像大小调整时选择。

2.2 初始化方法 __init__

  • 参数:

    • path: 数据集图像和标签的路径。
    • img_size: 指定图像的大小。
    • batch_size: 每个批次加载的图像数量。
    • augment: 是否使用数据增强。
    • hyp: 超参数,用于数据增强的配置。
    • 其他参数定义细节,比如图像的 padding,是否单类训练等。
  • 在初始化中,尝试从给定路径加载图像并相应地处理标签。它会处理如下几种情况:

    • 若路径为目录,则遍历目录中的所有图像文件,添加到 self.im_files 列表。
    • 若路径为文件,则读取文件内容,解析出图像路径,加到列表中。
    • 会检查是否有找到有效的图像,否则抛出异常。
  • 检查是否存在缓存的标签数据以加快加载速度。如果缓存存在,也会检查其版本和哈希值。

  • 解析标签和形状数据,并使用它们进行训练时所需的图像过滤。

init主要功能:

  1. 赋值一些基础的self变量 用于后面在getitem中调用
  2. 得到path路径下的所有图片的路径self.img_files
  3. 根据imgs路径找到labels的路径self.label_files
  4. cache label
  5. Read cache 生成self.labels、self.shapes、self.img_files、self.label_files、self.batch、self.n、self.indices等变量
  6. 为Rectangular Training作准备: 生成self.batch_shapes
  7. 是否需要cache image(一般不需要,太大了)

2.2.1 参数设置

self.img_size = img_size
self.augment = augment
self.hyp = hyp
self.image_weights = image_weights
self.rect = False if image_weights else rect
self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
self.path = path
self.albumentations = Albumentations(size=img_size) if augment else None

这段代码是一个类初始化器中的一些属性设置部分,主要用于配置与数据加载和增强相关的参数。下面是逐步分解和详细解释:

  1. self.img_size = img_size
    这行代码将传入的 img_size 参数赋值给实例变量 self.img_size,表示图像的大小。通常用于指定输入图像的高度和宽度。

  2. self.augment = augment
    这行代码将是否进行数据增强的布尔值 augment 赋值给实例变量 self.augment。若此值为 True,则在训练过程中对图像进行一些随机变换,以提高模型的鲁棒性。

  3. self.hyp = hyp
    此行将包含超参数的字典 hyp 赋值给实例变量 self.hyp。这些超参数可能用于调整数据增强的具体策略,如旋转角度、缩放比例等。

  4. self.image_weights = image_weights
    这行代码将图像权重的布尔值传入并赋值给 self.image_weights。如果为 True,则在数据加载时会考虑每张图像的权重,通常用于不平衡的数据集。

  5. self.rect = False if image_weights else rect
    通过条件表达式,如果 image_weights 为 True,则将 self.rect 设置为 False;否则,将 self.rect 设置为传入的 rect 值。这表示矩形批次的处理会受到图像权重的影响。

  6. self.mosaic = self.augment and not self.rect
    这行代码设定了 self.mosaic,用于控制在训练时是否使用马赛克增强。只有当进行数据增强且未使用矩形批次时,self.mosaic 才会被设置为 True,这意味着在训练时会将四张图像合并为一张图像以提高训练效果。

  7. self.mosaic_border = [-img_size // 2, -img_size // 2]
    这行代码设置了马赛克合成时的边界,通常用于在合成时保持图像的中心位置,从而避免合成图像出现空白区域。

  8. self.stride = stride
    将步幅(stride)参数赋值给 self.stride,步幅通常用于在卷积神经网络中控制卷积的滑动步长。

  9. self.path = path
    将数据集路径的参数传入并赋值给 self.path,用于指定图像和标签所在的目录或文件。

  10. self.albumentations = Albumentations(size=img_size) if augment else None
    这行代码创建数据增强的实例 Albumentations,并根据 img_size 设置图像尺寸。如果 augment 为 False,则 self.albumentations 为 None,这表明在训练中不进行增强。

这段代码主要功能是初始化数据加载相关的配置和参数,特别是与图像大小、数据增强、图像权重及其处理等方面直接相关。主要是为后面的图像数据加载和预处理步骤做好准备,确保在训练过程中能够有效地增强数据,提高模型的性能和泛化能力。

 2.2.2 加载图像文件

 try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                print("path",p)
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t]  # to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # to global path (pathlib)
                else:
                    raise FileNotFoundError(f'{prefix}{p} does not exist')
            self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert self.im_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e

这段代码的主要功能是从给定的路径中加载图像文件。下面将逐步分解并详细解释这段代码:

  1. 初始化文件列表:

    f = []  # image files
    

    创建一个空列表 f 用于存储找到的图像文件路径。

  2. 处理路径:

    for p in path if isinstance(path, list) else [path]:
    

    这行代码检查 path 是否是一个列表,如果是,则遍历该列表;如果不是,则将其包装成一个列表进行遍历。

  3. 路径转换:

    p = Path(p)  # os-agnostic
    

    使用 Path 对象来确保路径的操作与操作系统无关。

  4. 打印路径:

    print("path", p)
    

    输出当前处理的路径,帮助调试和跟踪。

  5. 检查路径类型并加载文件:

    if p.is_dir():  # dir
        f += glob.glob(str(p / '**' / '*.*'), recursive=True)
    
    • 如果 p 是一个目录,则使用 glob 模块查找该目录下(包括子目录)所有的文件。'**/*.*'表示匹配所有类型的文件。
    elif p.is_file():  # file
        with open(p) as t:
            t = t.read().strip().splitlines()
            parent = str(p.parent) + os.sep
            f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t]  # to global path
    
    • 如果 p 是一个文件,则打开该文件,读取其内容,并将每一行(文件路径)存储在列表 t 中。接着,构建其父目录的路径,处理路径以确保它们是绝对路径。如果路径以 ./ 开头,则将其替换为父目录路径。
  6. 处理无效路径:

    else:
        raise FileNotFoundError(f'{prefix}{p} does not exist')
    
    • 如果 p 既不是目录也不是文件,则抛出 FileNotFoundError,指示该路径不存在。
  7. 过滤和排序图像文件:

    self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
    
    • 将 f 中的文件路径根据扩展名滤出有效的图像文件(与 IMG_FORMATS 变量中定义的格式匹配),并用当前操作系统的路径分隔符替换路径中的斜杠,同时对文件路径进行排序。
  8. 检查是否找到图像:

    assert self.im_files, f'{prefix}No images found'
    
    • 确保至少找到了一个图像文件。如果没有找到,断言将失败并抛出异常。
  9. 异常处理:

    except Exception as e:
        raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
    
    • 捕获所有异常,并抛出一个新的异常,包含上下文信息,便于排查问题。

这段代码的主要功能是从指定的路径(可能是文件或目录)中加载图像文件。它首先检查路径的类型,然后递归地查找所有图像文件,存储有效的图像文件路径并进行排序,确保没有遗漏。若发生错误则抛出描述性异常。这段代码是数据加载流程中的一个重要环节,确保后续处理步骤有正确的文件输入。

2.2.3 检查数据集的缓存

        # Check cache
        self.label_files = img2label_paths(self.im_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except Exception:
            cache, exists = self.cache_labels(cache_path, prefix), False  # run cache ops

这段代码的功能是检查数据集的缓存(cache),以提高数据加载的效率。下面是逐步分解和详细解释该代码的每一部分:

  1. 获取标签文件路径

    self.label_files = img2label_paths(self.im_files)  # labels
    

    这行代码调用了 img2label_paths 函数,将图像文件的路径转换为对应的标签文件路径。self.im_files 是一个包含图像文件路径的列表,self.label_files 将保存其对应的标签文件路径。

  2. 定义缓存路径

    cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
    

    这段代码根据 p (可能是数据集的路径)是否是一个文件来决定缓存路径。如果 p 是一个文件,则直接使用 p 作为缓存路径;否则,使用标签文件的第一条路径的父目录作为基础路径,并将其后缀改为 .cache,来定义一个新的缓存文件路径。

  3. 加载缓存

    try:
        cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
    

    这行代码尝试使用 numpy 的 load 函数从上一步定义的 cache_path 中加载缓存文件。allow_pickle=True 允许加载包含 Python 对象的缓存。如果成功加载,将返回一个字典 cache,并将 exists 设为 True

  4. 验证缓存版本和内容

    assert cache['version'] == self.cache_version  # matches current version
    assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
    

    这两行代码使用 assert 语句检查缓存的版本是否与当前代码使用的缓存版本匹配(self.cache_version),并检查缓存中存储的文件哈希是否与当前图像和标签文件的哈希相同。如果任何断言失败,则会引发异常。

  5. 异常处理

    except Exception:
        cache, exists = self.cache_labels(cache_path, prefix), False  # run cache ops
    

    如果以上尝试加载缓存的操作失败(例如缓存文件不存在、版本不匹配等),则进入 except 块。这里调用 self.cache_labels 方法重新生成标签的缓存,并将 exists 设置为 False,表示缓存文件不存在。

这段代码的主要功能是检查和处理数据集的缓存文件,以加速后续的数据加载过程。通过加载已有的缓存,避免重复计算和IO操作,当缓存有效且匹配时,可以有效提升数据处理的性能;如果缓存无效或缺失,则会调用 .cache_labels() 方法重新生成缓存。这种机制确保了在训练模型或进行数据处理时,能够高效地访问标签信息。

 2.2.4 扫描结果

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'

这段代码主要用于处理和展示数据集缓存的扫描结果。我们逐步分解并详细解释代码的每个部分:

  1. 获取缓存结果:

    nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
    

    这行代码通过 cache.pop('results') 从缓存中提取出一个元组,分别表示:

    • nf: 找到的有效图像数量。
    • nm: 找不到的图像数量。
    • ne: 空白图像的数量。
    • nc: 损坏图像的数量。
    • n: 总图像数量。
  2. 检查条件并打印扫描进度:

    if exists and LOCAL_RANK in {-1, 0}:
        d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
        tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
    
    • if exists and LOCAL_RANK in {-1, 0}: 这行代码检查缓存文件是否存在,并且检查当前进程是否是主进程(即在分布式训练中,rank 为 -1 或 0)。
    • 如果条件为真,它将生成描述信息,并使用 tqdm 来显示进度条,显示扫描的状态信息,包括总共有多少有效图像、背景图像的数量以及损坏图像的数量。
  3. 记录警告信息:

    if cache['msgs']:
        LOGGER.info('\n'.join(cache['msgs']))  # display warnings
    

    这一段代码会检查缓存中是否有警告信息(cache['msgs']),如果有,就通过日志记录器 LOGGER 将警告信息输出。

  4. 确保找到有效标签:

    assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
    

    最后这条断言确保找到的有效标签数量 nf 大于零,或者如果不执行数据增强(not augment),则允许继续。如果没有找到有效标签,将抛出一个AssertionError,指出在指定的 cache_path 中未找到标签,而无法开始训练。并提供帮助链接 HELP_URL

这段代码的主要功能是展示数据集缓存的扫描结果,并记录相关的警告信息。它确保在开始训练之前,找到了足够数量的有效标签,提示用户进行检查和调整。整体流程对于模型训练准备工作至关重要,确保所需的数据正常可用。

2.2.5 读取数据

       # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels, shapes, self.segments = zip(*cache.values())
        nl = len(np.concatenate(labels, 0))  # number of labels
        assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
        self.labels = list(labels)
        self.shapes = np.array(shapes)
        self.im_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update

这段代码的主要功能是从缓存中读取数据,并进行一些必要的数据处理和验证,以便为后续的训练或处理阶段做好准备。下面是对这个代码片段的逐步分解和详细解释:

  1. 读取缓存并移除多余的项

    [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
    

    这行代码使用列表推导式遍历元组 ('hash', 'version', 'msgs'),从 cache 字典中移除这些不必要的项。列表推导式的结果并未被使用,这种情况下,使用普通的 for 循环会更合理。移除这些项通常是因为它们不需要在后续处理中使用。

  2. 解压缓存中的值

    labels, shapes, self.segments = zip(*cache.values())
    

    这行代码将 cache 字典的所有值进行解压和拆分。zip 函数将 cache.values() 的结果(这通常是一个迭代器,包含多个元组或列表)打包成元组。这意味着对于每一个缓存条目,它会返回三个部分:labelsshapes 和 segments,分别存储不同的信息。

  3. 计算标签数量

    nl = len(np.concatenate(labels, 0))  # number of labels
    

    这行代码将所有的标签(labels)拼接到一起,并计算标签的总数量。np.concatenate(labels, 0) 将列表中的数组沿着第一个维度(行)拼接,因此 nl 变量将包含所有标签的总数量。

  4. 验证标签的数量

    assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
    

    这行代码是一个断言语句,用于检查标签的数量是否大于零。如果 nl 等于 0 并且 augment(是否进行数据增强)为真,那么将抛出一个 AssertionError,提示所有标签都为空而无法开始训练。这确保了在开始训练之前,确实有标签可以使用。

  5. 更新类的属性

    self.labels = list(labels)
    self.shapes = np.array(shapes)
    self.im_files = list(cache.keys())  # update
    self.label_files = img2label_paths(cache.keys())  # update
    
    • self.labels = list(labels):将解压出来的 labels 转换成列表并赋值给类的属性 self.labels
    • self.shapes = np.array(shapes):将解压出来的 shapes 转换为 NumPy 数组,方便后续的数值计算和处理。
    • self.im_files = list(cache.keys()):更新 self.im_files,将缓存中的所有键(即图像文件路径)转换为列表。
    • self.label_files = img2label_paths(cache.keys()):调用 img2label_paths 函数,将图像文件路径转换为相应的标签路径,更新为 self.label_files

这段代码主要用于从一个缓存中读取和处理数据,以为后续的训练准备必要的信息。这包括清理缓存字典中的无用项、提取标签及其形状信息、验证标签的有效性,以及更新图像和标签的文件路径。通过这些步骤,代码确保了训练时数据的完整性和有效性,从而避免了在没有标签的情况下启动训练的错误。

 2.2.6 最小项目数量

        # Filter images
        if min_items:
            include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
            LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
            self.im_files = [self.im_files[i] for i in include]
            self.label_files = [self.label_files[i] for i in include]
            self.labels = [self.labels[i] for i in include]
            self.segments = [self.segments[i] for i in include]
            self.shapes = self.shapes[include]  # wh

这段代码的作用是根据一个指定的最小项目数量 (min_items) 过滤数据集中图像和对应的标签。下面是对代码的逐步分解和详细解释:

  1. 检查最小项目数量:

    if min_items:
    

    这一行检查变量 min_items 是否为真值(非零)。如果 min_items 为零或未定义,则跳过整个过滤过程。

  2. 生成包含图像的索引:

    include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
    
    • 使用列表生成式遍历 self.labels 中每个元素 x,用 len(x) >= min_items 判断每个标签列表的长度是否大于等于 min_items
    • 将结果转换为 NumPy 数组。
    • nonzero()[0] 返回一个包含满足条件的索引的数组,表示这些索引位置的标签长度满足最小要求。
    • 通过 astype(int) 将结果转换为整数类型。
  3. 记录过滤的图像数量:

    LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
    
    • 记录和打印过滤后的图像数量。n 是原始图像的数量,len(include) 是满足条件的图像数量。因此,n - len(include) 是被过滤掉的图像数量。
  4. 更新图像和标签列表:

    self.im_files = [self.im_files[i] for i in include]
    self.label_files = [self.label_files[i] for i in include]
    self.labels = [self.labels[i] for i in include]
    self.segments = [self.segments[i] for i in include]
    self.shapes = self.shapes[include]  # wh
    
    • 这些行通过索引列表 include 更新图像文件、标签文件、标签、段落和形状的列表。
    • 每个更新的列表只保留那些满足最小项目数量要求的条目。

这段代码的主要功能是过滤数据集中的图像和其对应的标签,以确保仅保留那些具有至少 min_items 个标签的图像。这样做的目的是在处理数据时,确保用于训练或验证的数据是高质量的,避免引入过多缺乏标签或标签不完整的图像,从而提高模型的训练效率和准确性。最终,该代码将更新所有相关数据结构,仅保留符合条件的图像及其相应信息。

 2.2.7 创建索引

        # Create indices
        n = len(self.shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

这段代码的目的是为数据加载器创建索引,以便在训练或验证过程中能够按照批次(batch)处理图像数据。以下是对每一行代码的详细解释:

  1. n = len(self.shapes)

    • 这一行计算当前数据集中图像的数量,并将其存储在变量 n 中。self.shapes 是一个包含每个图像形状的列表,长度即为图像的数量。
  2. bi = np.floor(np.arange(n) / batch_size).astype(int)

    • 这一行生成一个批次索引数组。np.arange(n) 创建一个从 0 到 n-1 的数组,即图像的索引。
    • 将每个索引除以 batch_size,得到每个图像所属的批次编号(batch index)。
    • 使用 np.floor 将结果向下取整,以确保每个图像都分配到一个有效的批次编号。最后,使用 astype(int) 将结果转换为整型。
  3. nb = bi[-1] + 1

    • 这一行计算总批次的数量。bi[-1] 取 bi 数组的最后一个值(即最后一个图像的批次编号),加 1 之后得到的即为总批次的数量。
  4. self.batch = bi

    • 将计算得到的批次索引数组 bi 存储到对象的属性 self.batch 中,便于后续使用。
  5. self.n = n

    • 将图像数量 n 存储到对象的属性 self.n 中,以便其他方法访问和使用。
  6. self.indices = range(n)

    • 这一行创建一个从 0 到 n-1 的范围对象 self.indices,这将用于确保在数据加载时按序或随机顺序访问图像。

这段代码的主要功能是为数据加载器创建批次索引和图像索引。通过这些索引,加载器可以有效地将数据集切分成多个批次,以便在训练和验证过程中对图像进行批量处理。这种批次处理是深度学习中常用的方法,可以提高计算效率,并有助于模型的收敛。

2.2.8 更新标签

        # Update labels
        include_class = []  # filter labels to include only these classes (optional)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
                    self.segments[i] = segment[j]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0
                if segment:
                    self.segments[i][:, 0] = 0

这段代码的主要功能是更新标签,以便在训练过程中对数据进行过滤和格式调整。

  1. 初始化过滤条件

    include_class = []  # filter labels to include only these classes (optional)
    include_class_array = np.array(include_class).reshape(1, -1)
    
    • 这里定义了一个include_class列表,用于存放需要包含的特定类(可选)。
    • 使用np.array()将其转化为NumPy数组,并且通过reshape(1, -1)将其调整为二维数组形式,以便后续进行比较。
  2. 遍历标签和分段信息

    for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
    
    • 这里使用enumerate同时遍历self.labels(标签列表)和self.segments(分段列表)。i是索引,labelsegment分别是当前索引对应的标签和分段信息。
  3. 根据过滤条件更新标签

    if include_class:
        j = (label[:, 0:1] == include_class_array).any(1)
        self.labels[i] = label[j]
        if segment:
            self.segments[i] = segment[j]
    
    • if include_class:判断是否定义了需要包含的类。
    • 使用布尔索引j来筛选标签,(label[:, 0:1] == include_class_array).any(1)查看每个标签的类是否在include_class_array中,如果在,jTrue
    • 将标签和可能的分段信息更新为经过过滤后的结果。self.labels[i] = label[j]表示只保留符合条件的标签。
  4. 单类训练的处理

    if single_cls:  # single-class training, merge all classes into 0
        self.labels[i][:, 0] = 0
        if segment:
            self.segments[i][:, 0] = 0
    
    • if single_cls:用于检查是否进行单类训练。
    • 在单类训练中,将所有标签的类索引统一设置为0,即所有对象都被视为同一类。
    • 同样地,如果分段信息存在,也将其类索引设置为0。

这段代码的主要功能是对训练数据中的标签进行更新。它首先允许选定特定类的数据进行过滤,通过更新标签列表来去除不需要的类。其次,在进行单类训练时,将所有标签统一为同一类(通常为0),以适应特定的训练需求。这种处理使得数据在模型训练之前能够更好地进行准备,确保模型只关注指定的类或统一的类,从而增强训练的效果和效率。

 2.2.9 矩形训练

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.im_files = [self.im_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.segments = [self.segments[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

这段代码的主要功能是为模型的训练设置处理图像的形状,特别是在进行“矩形训练”时使用的图像的形状调整。下面是代码的逐步分解和详细解释:

  1. 矩形训练的检查

    if self.rect:
    

    这行代码检查是否启用了“矩形训练”。如果 self.rect 为 True,则进入代码块进行图像形状的适配。

  2. 按宽高比排序

    s = self.shapes  # wh
    ar = s[:, 1] / s[:, 0]  # aspect ratio
    irect = ar.argsort()
    
    • s 是一个包含图像原始形状(宽度和高度)的数组。
    • ar 是一个数组,存储每个图像的宽高比(高度/宽度)。
    • irect 使用 argsort() 方法获取宽高比的排序索引,这样可以按宽高比从小到大排列图像。
  3. 根据排序索引重新排列文件和标签

    self.im_files = [self.im_files[i] for i in irect]
    self.label_files = [self.label_files[i] for i in irect]
    self.labels = [self.labels[i] for i in irect]
    self.segments = [self.segments[i] for i in irect]
    self.shapes = s[irect]  # wh
    ar = ar[irect]
    
    • 根据 irect 中的索引重新排列图像文件名、标签文件、标签和分段信息,以确保它们的顺序与新的宽高比顺序相一致。
    • self.shapes 和 ar 也会根据这个索引进行更新,保持一致性。
  4. 设置训练图像的形状

    shapes = [[1, 1]] * nb
    for i in range(nb):
        ari = ar[bi == i]
        mini, maxi = ari.min(), ari.max()
        if maxi < 1:
            shapes[i] = [maxi, 1]
        elif mini > 1:
            shapes[i] = [1, 1 / mini]
    
    • shapes 初始化为一个包含 [1, 1] 元组的列表,大小为批次数 nb
    • 在循环中,通过 bi(批次索引)来选择属于每个批次的图像的宽高比。
    • 接着,使用 min() 和 max() 方法得到该批次内的宽高比的最小值和最大值。
    • 根据最大和最小的宽高比,调整 shapes 中的值:
      • 如果最大宽高比小于 1,说明图像比较“瘦”,将宽高比设置为 [maxi, 1]
      • 如果最小宽高比大于 1,说明图像比较“宽”,将宽高比设置为 [1, 1 / mini]
  5. 计算最终的批量形状

    self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
    
    • 根据确定的 shapes,计算每个批次的最终形状。
    • 使用 np.ceil() 处理计算结果,确保每个批次的形状向上取整,并应用填充以保证每个图像都符合模型输入的要求。

这段代码主要负责在启用矩形训练时,根据图像的宽高比对训练数据进行处理。它首先根据宽高比对图像进行排序,然后确保图像文件、标签和相应的形状一致。接着,它为每一批次图像设置合适的形状,以确保在训练过程中充分利用输入数据的形状特征。这不仅提高了模型的训练效率,还可能改善模型的性能,因为它适应了不同形状图像的多样性。

 2.2.10 图像缓存到本地

       # Cache images into RAM/disk for faster training
        if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
            cache_images = False
        self.ims = [None] * n
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache_images:
            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
            self.im_hw0, self.im_hw = [None] * n, [None] * n
            fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
            results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
            pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache_images == 'disk':
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
            pbar.close()

这段代码的主要功能是将图像缓存到RAM或磁盘中,以提高训练的速度。在深度学习中,图像的读取速度和存储效率至关重要,因此使用缓存技术来减少I/O操作可以显著提高训练效率。下面是对代码的逐步分解和详细解释:

  1. 条件判断

    if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
        cache_images = False
    
    • 这部分代码检查cache_images是否等于'ram',并且通过调用self.check_cache_ram()方法检查是否有足够的内存来缓存图像。如果没有足够的内存,则cache_images设置为False,表示不进行RAM缓存。
  2. 初始化图像和.np文件路径

    self.ims = [None] * n
    self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
    
    • 创建一个包含None的列表self.ims,用于存储缓存的图像,长度为图像数量n
    • 创建self.npy_files,这是一个包含相应图像文件路径的列表,但是扩展名从原来的格式更改为.npy,表示这些文件将用于存储缓存图像的NumPy数组格式。
  3. 缓存图像

    if cache_images:
    
    • 在此条件下,如果cache_imagesTrue,则进入缓存流程。
  4. 初始化缓存变量

    b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
    self.im_hw0, self.im_hw = [None] * n, [None] * n
    
    • b用于跟踪缓存图像的总字节数,初始值为0。
    • gb表示1GB的字节数(1 << 30表示2的30次方,即1073741824)。
    • self.im_hw0 和 self.im_hw 是用于存储图像的原始高度和宽度的列表。
  5. 选择缓存函数

    fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
    
    • 根据cache_images的值,确定用于缓存图像的函数。如果cache_images'disk',则使用self.cache_images_to_disk,否则使用图像加载函数self.load_image
  6. 并行处理

    results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
    
    • 使用线程池并行执行图像缓存操作,这里NUM_THREADS是线程的数量。imap方法会对从0到n的每一个索引调用选定的函数fcn
  7. 进度条显示和累加缓存大小

    pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
    
    • tqdm用于显示处理进度条,enumerate(results)将返回带有进度的结果。
    • 通过遍历结果更新缓存总大小b,并根据模式(RAM或磁盘)分别处理:
    for i, x in pbar:
        if cache_images == 'disk':
            b += self.npy_files[i].stat().st_size
        else:  # 'ram'
            self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
            b += self.ims[i].nbytes
        pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
    
    • 如果是从磁盘缓存,则根据.npy文件的大小累加字节数。
    • 如果是缓存到RAM,则更新相应的图像和其大小。
  8. 关闭进度条

    pbar.close()
    
    • 完成图像缓存后,关闭进度条。

该段代码主要用于缓存图像数据,这样在后续的训练过程中可以更快速地访问这些数据,从而减少训练时的数据加载时间。它支持将图像缓存在内存中,也可以选择将图像保存到磁盘格式,以适应不同的硬件资源。通过使用多线程技术,进一步提高了性能,使得在大数据集上训练变得更加高效。

2.3 函数 check_cache_ram

  • 该函数检查是否有足够的内存来缓存图像,确保缓存不会超过可用内存的限制。
def check_cache_ram(self, safety_margin=0.1, prefix=''):
        # Check image caching requirements vs available memory
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        n = min(self.n, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
            b += im.nbytes * ratio ** 2
        mem_required = b * self.n / n  # GB required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
                        f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
        return cache

这段代码是一个方法 check_cache_ram,其主要功能是检查在将图像数据缓存到内存中时所需的内存与可用内存之间的关系以决定是否可以在RAM中缓存数据。下面是逐步分解和详细解释代码的过程:

  1. 函数定义

    def check_cache_ram(self, safety_margin=0.1, prefix=''):
    
    • 方法 check_cache_ram 是类中的一个实例方法,接收两个参数:safety_margin(安全边际,默认值为0.1)和 prefix(用于日志输出的前缀,默认值为空字符串)。
  2. 变量初始化

    b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
    
    • b 用于累积缓存图像所需的字节数,初始为0。
    • gb 表示一个GB的字节数,即1 << 30 欧洲字节中的1GB。
  3. 样本图像数量

    n = min(self.n, 30)  # extrapolate from 30 random images
    
    • n 取类属性 self.n(总图像数量)和30之间的最小值。这样做是为了从最多30张随机图像中推断出需要的内存。
  4. 循环读取随机图像

    for _ in range(n):
        im = cv2.imread(random.choice(self.im_files))  # sample image
        ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
        b += im.nbytes * ratio ** 2
    
    • 循环 n 次,每次随机选择一张图像文件。
    • 使用 OpenCV 读取图像 im
    • 计算 ratio 为目标图像尺寸与原图像最大边的比率。这个比率用于后续计算重新调整图像尺寸后的内存占用。
    • 累加缓存图像的字节数 b,这里使用了图像的字节数 im.nbytes 乘以 ratio 的平方,这是因为图像的尺寸调整是二维的。
  5. 计算所需内存

    mem_required = b * self.n / n  # GB required to cache dataset into RAM
    
    • 计算将整个数据集缓存到内存中所需的总内存 mem_required。这个计算基于已经计算的图像字节数 b,并按比例推算到总图像数。
  6. 获取当前可用内存

    mem = psutil.virtual_memory()
    
    • 使用 psutil 库获取系统的虚拟内存信息,包含已用、可用及总内存。
  7. 判断是否可以缓存

    cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
    
    • 检查所需内存(加上安全边际)是否小于当前可用内存,得出是否可以进行缓存的布尔值 cache
  8. 日志输出

    if not cache:
        LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
                    f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
    
    • 如果不可以缓存,则记录所需内存、可用内存,以及当前的缓存状态(是否能够缓存图像)。
  9. 返回结果

    return cache
    
    • 返回 cache 布尔值,表示是否可以缓存图像到内存中。

此函数的主要功能是检查在将图像数据集加载到内存时所需内存的大小,并与系统可用内存进行比较,以判断是否能够成功进行内存缓存。它通过示例图像的读取和大小分析来推算整个数据集的内存需求,同时保证有一定的安全余量,以防止系统因内存不足而崩溃。如果无法缓存,则通过日志记录相关信息,保证用户及时了解内存使用情况。

2.4 函数 cache_labels

  • 该函数负责缓存标签到指定的路径,检测图像和标签的有效性,准备进行训练
  • 这个函数用于加载文件路径中的label信息生成cache文件。cache文件中包括的信息有:im_file, l, shape, segments, hash, results, msgs, version等,。
    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning {path.parent / path.stem}..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                        desc=desc,
                        total=len(self.im_files),
                        bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"

        pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x

以下是对所提供代码的逐步分解和详细解释:

  1. 函数定义:

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
    
    • 该函数的目的是用于缓存数据集标签。
    • path参数指定缓存文件的存储位置,默认路径为./labels.cache
    • prefix为输出信息的前缀,用于日志记录。
  2. 初始化字典和计数器:

    x = {}  # dict
    nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
    
    • x:用来存储图像的标签、形状和分段信息。
    • nm(missing):缺失标签的数量。
    • nf(found):找到的标签数量。
    • ne(empty):空标签的数量。
    • nc(corrupt):损坏标签的数量。
    • msgs:存储警告消息的列表。
  3. 开始多进程处理:

    with Pool(NUM_THREADS) as pool:
    
    • 创建一个多进程池,用于并行处理标签验证。
  4. 进度条初始化:

    pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                desc=desc,
                total=len(self.im_files),
                bar_format=TQDM_BAR_FORMAT)
    
    • 使用tqdm库显示进度条。
    • verify_image_label是用于验证图像和标签的一项功能,zip函数将图像文件和标签文件的列表组合在一起。
  5. 结果累加和存储:

    for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
        nm += nm_f
        nf += nf_f
        ne += ne_f
        nc += nc_f
        if im_file:
            x[im_file] = [lb, shape, segments]
        if msg:
            msgs.append(msg)
        pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
    
    • 循环遍历每个文件的验证结果。
    • 更新缺失、找到、空和损坏的标签数量。
    • 如果成功找到图像文件,将其信息存储在字典x中。
    • 如果存在消息,将其存储到msgs列表中,并更新进度条描述。
  6. 关闭进度条:

    pbar.close()
    
    • 关闭进度条,表示处理完成。
  7. 记录日志:

    if msgs:
        LOGGER.info('\n'.join(msgs))
    if nf == 0:
        LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
    
    • 如果有消息,记录信息日志。
    • 如果没有找到任何标签,记录警告信息。
  8. 生成和存储缓存数据:

    x['hash'] = get_hash(self.label_files + self.im_files)
    x['results'] = nf, nm, ne, nc, len(self.im_files)
    x['msgs'] = msgs  # warnings
    x['version'] = self.cache_version  # cache version
    
    • 计算标签和图像文件的哈希值。
    • 将统计结果、消息和缓存版本信息添加到字典x中。
  9. 保存缓存到文件:

    try:
        np.save(path, x)  # save cache for next time
        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
        LOGGER.info(f'{prefix}New cache created: {path}')
    except Exception as e:
        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}')  # not writeable
    
    • 尝试将字典x保存到文件中,处理过程中可能出现的错误将通过日志记录。
    • 如果成功,则记录新创建的缓存文件路径。
  10. 返回缓存的数据:

return x
  • 返回包含标签信息的字典x

该代码的主要功能是缓存数据集的标签信息,包括标签是否存在、每个标签的形状、分段信息等,并生成一份统计报告,记录哪些标签被找到、缺失、空和损坏。它通过并行处理提高效率,同时使用进度条和日志记录实时反馈处理进度和结果,最终将结果保存到指定文件路径,以便于后续使用。这种机制对于深度学习模型训练前的数据准备尤为重要。

2.5 方法 __len__

  • 返回图像文件的数量。
  def __len__(self):
        return len(self.im_files)

这段代码是一个类中的方法,通常用于自定义类时实现某些特性。在这个例子中,该方法是__len__,它的主要功能是返回对象中图像文件的数量

  1. 方法定义

    def __len__(self):
    
    • 这行代码定义了一个名为__len__的方法。这个方法是特殊方法之一,Python 中的特殊方法通常以双下划线开头和结尾。__len__方法允许对象的长度可以使用内置的len()函数来获取。
  2. 返回图像文件数量

    return len(self.im_files)
    
    • self.im_files是该类的一个属性,通常是一个存储图像文件路径的列表或其他可迭代对象。
    • len(self.im_files)会计算并返回self.im_files中元素的数量,也就是说返回在该类的实例中存储的图像文件的总数。
    • return语句则把计算出的数量返回给调用者。

这段代码的主要功能是实现了一个长度计算方法,使得当对该对象调用len()函数时,能够返回对象中图像文件的数量。这通常是数据加载器类的一个通用功能,用于快速了解可以加载的图像数量,从而在训练或者推理过程中进行相应的数据处理和管理。

2.6 方法 __getitem__

  • 用于获取指定索引的图像和对应的标签。
  • 该方法支持数据增强,比如随机翻转、色彩增强与加剪切等处理。
  • 这部分是数据增强函数,一般一次性执行batch_size次。
    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(img,
                                                 labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

            # Cutouts
            # labels = cutout(img, labels, p=0.5)
            # nl = len(labels)  # update after cutout

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.im_files[index], shapes

这段代码是一个 Python 类中的 __getitem__ 方法,主要用于从数据集中加载图像和相应的标签。此方法通常是在训练过程中被调用,以便在每个训练步骤中获取所需的数据。以下是对代码逐步分解和详细解释:

  1. 索引处理

    index = self.indices[index]
    

    这行代码根据指定的索引从 self.indices 中获取当前索引,self.indices 可能是线性、随机打乱或者基于图像权重的索引。

  2. 超参数和拼 mosaics

    hyp = self.hyp
    mosaic = self.mosaic and random.random() < hyp['mosaic']
    

    这里加载超参数 self.hyp,并决定是否使用马赛克增强(mosaic augmentation)。通过随机数判断是否应用马赛克。

  3. 加载马赛克图像

    if mosaic:
        img, labels = self.load_mosaic(index)
        shapes = None
    

    如果选择了马赛克增强,调用 load_mosaic 方法加载马赛克图像和标签。

  4. MixUp 增强

    if random.random() < hyp['mixup']:
        img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
    

    如果条件满足,则进行 MixUp 增强,即将当前图像与另一随机图像混合。

  5. 加载单张图像

    else:
        img, (h0, w0), (h, w) = self.load_image(index)
    

    如果不使用马赛克,则加载单张图像和其尺寸。

  6. 图像调整(Letterbox)

    shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size
    img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
    

    通过 letterbox 函数调整图像大小,以适应网络的输入要求。

  7. 标签格式转换

    labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
    

    将标签从相对坐标(xywh)转换为绝对坐标(xyxy)。

  8. 数据增强

    if self.augment:
        img, labels = random_perspective(img, labels, ...)
    

    如果启用数据增强,使用 random_perspective 函数进行透视变换。

  9. 标签数量处理

    nl = len(labels)
    if nl:
        labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
    

    计算标签的数量,并将标签转换为相对坐标。

  10. 更多数据增强: 包括色彩空间增强和翻转图像等:

    if self.augment:
        img, labels = self.albumentations(img, labels)
        ...
        img = np.flipud(img)  # 上下翻转
        img = np.fliplr(img)  # 左右翻转
    
  11. 准备输出标签

    labels_out = torch.zeros((nl, 6))
    if nl:
        labels_out[:, 1:] = torch.from_numpy(labels)
    

    创建一个零填充的输出标签数组,然后将标签复制到其中。

  12. 图像数据转换

    img = img.transpose((2, 0, 1))[::-1]  # HWC 转换为 CHW
    img = np.ascontiguousarray(img)  # 确保内存连续性
    
  13. 返回结果: 最后,函数返回处理后的图像、标签及其对应的文件路径和形状信息:

    return torch.from_numpy(img), labels_out, self.im_files[index], shapes
    

__getitem__ 方法的主要功能是从数据集中根据给定的索引加载图像和标签,进行必要的数据增强,调整图像尺寸,并返回处理后的图像及标签。该方法实现了多种增强技术,包括马赛克、MixUp、随机透视、颜色空间变换等,从而增加训练数据的多样性。这对于深度学习模型的训练非常重要,能够提高模型的泛化能力。

2.7 方法 load_image

  • 加载指定索引的图像,处理图像为目标格式(BGR 转 RGB),并返回原始和调整后的图像大小。
  • 这个函数是根据图片index,从self或者从对应图片路径中载入对应index的图片

    并将原图中hw中较大者扩展到self.img_size, 较小者同比例扩展。

    会被用在LoadImagesAndLabels模块的getitem函数和load_mosaic模块中载入对应index的图片:

    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                assert im is not None, f'Image Not Found {f}'
            h0, w0 = im.shape[:2]  # orig hw
            r = self.img_size / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

这段代码是一个方法 load_image 的实现,属于一个类(可能是用于加载图像数据集的类)。以下是对该方法的逐步分解和详细解释:

  1. 函数定义和参数:

    def load_image(self, i):
    
    • 定义了一个名为 load_image 的方法,接受一个参数 i,表示索引,用于指定要加载的图像在数据集中的位置。
  2. 注释说明:

    # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
    
    • 注释解释了该方法的功能:加载位于索引 i 的图像,并返回图像、原始高度与宽度、调整后的高度与宽度。
  3. 获取图像信息:

    im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
    
    • im:当前索引 i 处图像的已缓存数据(如果存在)。
    • f:当前索引 i 处图像的文件路径。
    • fn:当前索引 i 处图像的 .npy 格式文件路径,用于数字数组的存储。
  4. 检查图像是否已缓存:

    if im is None:  # not cached in RAM
    
    • 如果 im 是 None,表示图像尚未缓存到内存中。
  5. 尝试加载 .npy 文件:

    if fn.exists():  # load npy
        im = np.load(fn)
    
    • 如果 .npy 文件存在,则使用 np.load(fn) 加载图像数据。
  6. 如果 .npy 文件不存在,则读取图像:

    else:  # read image
        im = cv2.imread(f)  # BGR
        assert im is not None, f'Image Not Found {f}'
    
    • 使用 OpenCV 的 cv2.imread(f) 读取原始图片。如果加载失败,抛出异常,指出未能找到图像。
  7. 获取原始图像的高度和宽度:

    h0, w0 = im.shape[:2]  # orig hw
    
    • h0 和 w0 分别表示原始图像的高度和宽度。
  8. 计算调整比例:

    r = self.img_size / max(h0, w0)  # ratio
    
    • 计算图像调整的比例 r,使得图像的最大边长与预设的 img_size 一致。
  9. 调整图像大小(如果需要):

    if r != 1:  # if sizes are not equal
        interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
        im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
    
    • 如果调整比例不等于1,使用合适的插值方法将图像调整到指定大小。插值方法依赖于是否进行数据增强。
  10. 返回结果:

    return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
    
    • 返回加载的图像 im、原始大小 (h0, w0) 和调整后的大小 im.shape[:2]
  11. 如果图像已经缓存,直接返回:

    return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized
    
    • 如果图像已成功缓存,直接返回已缓存的图像及其尺寸信息。

该 load_image 方法的主要功能是从数据集中加载指定索引的图像。如果该图像已经缓存,则直接返回;如果没有,则尝试从 .npy 文件加载图像数据,如果不成功,则从原始图像文件读取。方法同时处理图像的尺寸调整,确保其符合预期的输入尺寸。返回的结果包括加载的图像、原始图像的尺寸以及调整后的图像尺寸。这在图像处理和深度学习训练中非常重要,因为训练过程通常要求输入图像具有一致尺寸。

 2.8 方法 cache_images_to_disk

  • cache_images_to_disk:用于将图像保存为 numpy 数组以加速加载的函数。
  • 该类中的操作大部分使用 NumPy 和 PyTorch 结合 OpenCV,充分利用多线程加速加载。
 def cache_images_to_disk(self, i):
        # Saves an image as an *.npy file for faster loading
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

这段代码定义了一个名为 cache_images_to_disk 的方法,属于一个类(可能是处理图像数据的类),其主要功能是将图像缓存到磁盘,以便后续快速加载。下面是对这段代码的逐步分解和详细解释:

  1. 方法定义

    def cache_images_to_disk(self, i):
    
    • 这里定义了一个实例方法 cache_images_to_disk,它接受一个参数 i,这个参数通常表示图像在列表中的索引。
  2. 注释

    # Saves an image as an *.npy file for faster loading
    
    • 注释说明了此方法的目的是将图像保存为 .npy 文件,以便更快地加载。
  3. 获取文件路径

    f = self.npy_files[i]
    
    • 从 self.npy_files 列表中获取索引为 i 的文件路径,self.npy_files 应该是一个存储 .npy 文件路径的列表。
  4. 检查文件是否存在

    if not f.exists():
    
    • 使用 exists() 方法检查文件 f 是否已经存在。如果文件存在,说明图像已经被缓存到磁盘,无需再次保存;如果不存在,则执行后续步骤。
  5. 读取和保存图像

    np.save(f.as_posix(), cv2.imread(self.im_files[i]))
    
    • 使用 cv2.imread 从 self.im_files 列表中读取索引为 i 的图像。这个操作将返回一个图像数组。
    • 然后使用 np.save 方法将读取的图像数据保存为 .npy 文件,文件的路径是 f.as_posix()as_posix() 方法确保路径格式在不同操作系统上都能正确使用。

这个 cache_images_to_disk 方法的主要功能是将特定索引的图像从文件系统读取并以 NumPy 的 .npy 格式保存到磁盘,从而缓存图像数据以提高后续的加载速度。此操作在处理大型图像数据集时非常有用,因为它可以显著减少读取图像所需的时间,提高数据加载的效率。

2.9 方法 load_mosaic

  • 把四张图像拼接成一个马赛克图,随机从数据集中选择图像并填充对应的区域,适用于增强训练数据。
  • 这个模块就是很有名的mosaic增强模块,几乎训练的时候都会用它,可以显著的提高小样本的mAP。

    代码是数据增强里面最难的, 也是最有价值的,mosaic是非常非常有用的数据增强trick, 一定要熟练掌握。

 def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
        img4, labels4 = random_perspective(img4,
                                           labels4,
                                           segments4,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img4, labels4

这段代码定义了一个名为 load_mosaic 的方法,主要用于创建一个“马赛克”图像以便在训练时增强数据。下面逐步分解并详细解释这段代码。

  1. 方法定义及初始化

    def load_mosaic(self, index):
    

    该方法接收一个参数 index,这是当前图像在数据集中索引。

  2. 初始化变量

    labels4, segments4 = [], []
    s = self.img_size
    
    • labels4 和 segments4 用于存储合并后的标签和分段信息。
    • s 是图像的尺寸,用于后续创建合成图像的大小。
  3. 确定马赛克中心的坐标

    yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)
    
    • 随机生成马赛克图像的中心坐标,mosaic_border 定义了边界。
  4. 选择要加载的图像索引

    indices = [index] + random.choices(self.indices, k=3)
    random.shuffle(indices)
    
    • indices 包含当前图像的索引和3个随机选中的索引,目的是获取总共4张图像。
    • 打乱这些索引以便随机放置图像。
  5. 加载并放置图像

    for i, index in enumerate(indices):
        img, _, (h, w) = self.load_image(index)
    

    在循环中,对于每个 index,加载对应的图像。

    • 根据索引加载图像及其原始的高度和宽度 h 和 w
  6. 计算放置坐标

    • 根据当前图像的位置(左上、右上、左下、右下),计算在马赛克图像中应放置图像的坐标:
    if i == 0:  # top left
        # 计算坐标
    elif i == 1:  # top right
        # 计算坐标
    elif i == 2:  # bottom left
        # 计算坐标
    elif i == 3:  # bottom right
        # 计算坐标
    
    • 这些坐标计算使用了 max 和 min 函数,以确保图像不会超出矩阵边界。
  7. 图像合成

    img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]
    
    • 将当前图像按计算的坐标放置到马赛克图像中。
  8. 调整标签和分段

    labels, segments = self.labels[index].copy(), self.segments[index].copy()
    if labels.size:
        labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)
        segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
    labels4.append(labels)
    segments4.extend(segments)
    
    • 对应于每个图像,调整其标签和分段的坐标,适应新的合成图像的位置。
    • 使用辅助函数 xywhn2xyxy 和 xyn2xy 进行坐标转换。
  9. 合并和裁剪标签

    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)
    
    • 所有标签和分段信息合并为 labels4
    • 使用 np.clip 确保标签和分段坐标不会超出图像的边缘。
  10. 数据增强

    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
    img4, labels4 = random_perspective(img4, labels4, segments4, degrees=self.hyp['degrees'], translate=self.hyp['translate'], scale=self.hyp['scale'], shear=self.hyp['shear'], perspective=self.hyp['perspective'], border=self.mosaic_border)
    
    • 使用 copy_paste 进行随机复制粘贴增强。
    • 使用 random_perspective 增加随机透视变换效果,以不同的角度和比例对图像进行调整。
  11. 返回结果

    return img4, labels4
    
    • 返回合成好的马赛克图像和相应的标签。

load_mosaic 方法主要功能是从数据集中加载一张当前图像及三张随机图像,合成一个马赛克图像。它通过随机中心坐标、计算合适的放置位置、调整标签和进行数据增强来优化图像并增强模型训练的表现。通过这种方式,模型能够在不同的图像组合中学习到更多的特征,从而提高泛化能力。

 2.10 方法 load_mosaic9

  • 这个模块是作者的实验模块,将九张图片拼接在一张马赛克图像中。总体代码流程和load_mosaic4几乎一样,看懂了load_mosaic4再看这个就很简单了、
    def load_mosaic9(self, index):
        # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
        img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

        # Concat/clip labels
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
        img9, labels9 = random_perspective(img9,
                                           labels9,
                                           segments9,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img9, labels9

该函数 load_mosaic9 用于加载一个包含1张主要图像和8张随机图像的9图拼接(mosaic)图像

  1. 函数定义与初始化:

    def load_mosaic9(self, index):
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
    
    • labels9 和 segments9 是用来存储拼接图像对应的标签和分段信息。
    • s 是图像的目标尺寸。
    • indices 是一个包含所需加载图像的索引列表,包含当前图像的索引和随机选择的8个其他图像的索引。
  2. 初始化高度和宽度:

    hp, wp = -1, -1  # height, width previous
    
    • hp 和 wp 用于跟踪上一个图像的高度和宽度。
  3. 加载图像及拼接:

    for i, index in enumerate(indices):
        img, _, (h, w) = self.load_image(index)
    
    • 循环遍历各个索引以加载图像,获取高 h 和宽 w
  4. 在拼接图像中放置各个图像:

    if i == 0:  # center
        img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
        h0, w0 = h, w
        c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
    
    • 第0张图像放置在拼接图像的中心。
    • 其他图像(第1到第8张)根据不同的位置放置。在每种情况下,定义图像应该被放置的坐标 c
  5. 计算图像位置:

    padx, pady = c[:2]
    x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords
    
    • 计算如何将加载的图像放置到拼接图像的正确位置。
  6. 处理标签与分段信息:

    labels, segments = self.labels[index].copy(), self.segments[index].copy()
    if labels.size:
        labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
        segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
    labels9.append(labels)
    segments9.extend(segments)
    
    • 复制当前图像的标签和分段信息,并将其规模转换为拼接图像的坐标系统。将这些信息附加到 labels9 和 segments9 列表中。
  7. 更新拼接图像:

    img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
    hp, wp = h, w  # height, width previous
    
    • 将当前图像的部分内容放入拼接图像的适当位置,并更新记录上一个图像的尺寸。
  8. 偏移拼接图像:

    yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
    
    • 随机生成拼接图像的中心偏移量。
  9. 合并与裁剪标签:

    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]
    
    • 合并所有图像的标签与分段信息,并更新其坐标使之与拼接后的图像相符。
  10. 限制坐标范围:

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    
    • 将标签和分段坐标限制在合法的范围内,以处理映射过程中的潜在超出图像边界的问题。
  11. 数据增强处理:

    img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
    img9, labels9 = random_perspective(img9,
                                       labels9,
                                       segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove
    
    • 对拼接图像及其标签与分段信息应用数据增强,例如随机透视、程度调整等操作。
  12. 返回结果:

    return img9, labels9
    
    • 返回拼接后的图像和相应的标签信息。

此代码的主要功能是生成一个包含1张主图像和8张随机图像的9图拼接(mosaic)图像,适用于YOLOv5模型的数据增强过程。通过这种方式,模型可以从多个视角学习到更多的信息,从而提高其泛化能力。拼接后的图像连同标签信息也随之被处理,以确保目标检测任务中对物体位置的准确性。该方法增强了训练数据的多样性,提高了模型的鲁棒性。

2.11 方法 collate_fn

  • 处理批量数据,合并成张量以便于训练使用。
 def collate_fn(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im, 0), torch.cat(label, 0), path, shapes

这段代码定义了一个名为 collate_fn 的函数,主要用于数据加载过程中对一批数据进行整理和处理。下面是对代码逐步分解和详细解释:

  1. 函数定义

    def collate_fn(batch):
    
    • 定义一个名为 collate_fn 的函数,该函数接收一个参数 batch,它是一个包含多个样本的列表。
  2. 解包数据

    im, label, path, shapes = zip(*batch)  # transposed
    
    • zip(*batch) 将 batch 中的多个样本解包为相应的四个部分:im(图像),label(标签),path(图像路径),shapes(图像的原始形状)。这里的 *batch 意思是将 batch 中的每个元素(即单个样本)提取出来,形成一个新的元组集合。
    • 例如,如果 batch 中的每个样本包含 (im, label, path, shapes),那么 zip 函数将分别收集所有的 imlabelpath 和 shapes
  3. 添加目标图像索引

    for i, lb in enumerate(label):
        lb[:, 0] = i  # add target image index for build_targets()
    
    • 这段代码是一个循环,遍历每个标签 lbenumerate(label) 同时返回当前标签的索引 i 和标签本身 lb
    • lb[:, 0] = i 将每个标签的第一个元素(通常代表类的索引或目标的类型)设置为当前图像的索引 i。这一步骤的目的是为后续的目标构建提供图像索引。
  4. 合并并返回结果

    return torch.stack(im, 0), torch.cat(label, 0), path, shapes
    
    • torch.stack(im, 0) 将图像列表 im 沿着第一个维度合并成一个单一的张量,通常用于批处理。
    • torch.cat(label, 0) 将标签 label 列表合并为一个张量,通常也是按第一个维度连接。
    • 最后,函数返回完整的图像张量、合并的标签张量、图像路径列表及其原始形状信息。

这段代码的主要功能是对一批加载的数据(图像及其相关信息)进行整理,以便在后续处理阶段(如模型训练或评估)能够便利地访问每个图像及其对应的标签。它将图像和标签重塑为适合批处理的格式,并在标签中包含图像的索引,这对于目标检测和任务训练时的目标构建至关重要。

2.12 方法 collate_fn4 

    def collate_fn4(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
                                    align_corners=False)[0].type(im[i].type())
                lb = label[i]
            else:
                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im1)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()

        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4

这段代码是一个自定义的 collate_fn4 函数,主要用于处理一个批次的图像和标签数据,并为后续的模型训练准备格式化的输入数据。下面逐步分解并详细解释这段代码中的每一部分:

  1. 函数定义和参数:

    def collate_fn4(batch):
    

    该函数定义了一个名为 collate_fn4 的函数,其参数 batch 是一个由多个样本组成的列表,每个样本通常包含图像、标签、路径和形状等信息。

  2. 拆解样本:

    im, label, path, shapes = zip(*batch)  # transposed
    

    zip(*batch) 将 batch 列表中的每个样本拆解开来,分离出图像(im)、标签(label)、路径(path)和形状(shapes)。最终,它们会被组合为四个独立的元组。

  3. 获取样本数量:

    n = len(shapes) // 4
    

    从 shapes 中计算批次中的样本数量 n,这里假设每四个样本组合成一个新的样本。

  4. 初始化新的列表:

    im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
    

    初始化 im4 和 label4 列表,用于存放合成后的图像和标签,path4 和 shapes4 保持原样,从前 n 个样本中提取路径和形状。

  5. 定义平移张量:

    ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
    wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
    s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
    

    创建了平移向量和缩放因子的张量,这对图像标签的位置调整和缩放很重要。

  6. 循环合成图像和标签:

    for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
        i *= 4
    

    遍历前 n 个样本,在每一步中,i 被乘以4,用来处理每四个样本:

    • 随机选择处理方式:

      if random.random() < 0.5:
          im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
                              align_corners=False)[0].type(im[i].type())
          lb = label[i]
      

      以50%的概率,选择将当前图像 im[i] 放大2倍,并保持标签不变。

      • 拼接图像和标签:
      else:
          im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
          lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
      

      否则,将四个图像在宽度上拼接在一起,形成一个复合图像 im1,并且相应地调整标签的位置。

  7. 更新标签索引:

    for i, lb in enumerate(label4):
        lb[:, 0] = i  # add target image index for build_targets()
    

    更新每个标签的第一列,表示该标签对应的目标图像索引。

  8. 返回结果:

    return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
    

    最后,该函数以张量形式返回四个输出:

    • 合并后的图像 im4
    • 合并后的标签 label4
    • 图像路径 path4
    • 图像形状 shapes4

collate_fn4 函数的主要功能是在处理批量数据时,为深度学习模型准备格式化的输入。它合成多个图像和对应的标签,同时通过随机算法增强数据(如图像放大和拼接),从而提升模型的训练效果。最终,它返回处理后的图像、标签、路径和形状,用于后续训练过程。

 三、【分类】创建数据加载器(dataloader)

def create_classification_dataloader(path,
                                     imgsz=224,
                                     batch_size=16,
                                     augment=True,
                                     cache=False,
                                     rank=-1,
                                     workers=8,
                                     shuffle=True):
    # Returns Dataloader object to be used with YOLOv5 Classifier
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)
    return InfiniteDataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle and sampler is None,
                              num_workers=nw,
                              sampler=sampler,
                              pin_memory=PIN_MEMORY,
                              worker_init_fn=seed_worker,
                              generator=generator)  # or DataLoader(persistent_workers=True)

这段代码的功能是创建一个用于分类的PyTorch数据加载器(DataLoader),可以与YOLOv5分类器配合使用。下面是逐步分解和详细解释代码的过程:

  1. 函数定义

    def create_classification_dataloader(path, imgsz=224, batch_size=16, augment=True, cache=False, rank=-1, workers=8, shuffle=True):
    
    • path: 数据集的根目录路径,通常包含图像和标签。
    • imgsz: 输入图像的大小,默认为224。
    • batch_size: 每个批次的数据量,默认为16。
    • augment: 是否使用数据增强,默认为True。
    • cache: 是否使用缓存,默认为False。
    • rank: 在分布式训练中的进程rank,默认为-1(不启用分布式)。
    • workers: 数据加载的工作线程数,默认为8。
    • shuffle: 是否在每个epoch开始时打乱数据,默认为True。
  2. 分布式数据加载的初始化

    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
    

    使用torch_distributed_zero_first(rank)上下文管理器。如果启用了分布式训练(DDP),则只在rank为0时初始化数据集,以避免重复初始化。

  3. 创建数据集

    dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
    

    利用提供的参数创建一个ClassificationDataset实例,这个数据集会加载指定路径下的图像,并进行相应的处理。

  4. 调整批量大小

    batch_size = min(batch_size, len(dataset))
    

    确保批量大小不超过数据集的大小。

  5. 计算设备数量和线程数量

    nd = torch.cuda.device_count()
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
    
    • nd获取当前可用的GPU数量。
    • nw计算可用的工作线程数,确保不超过CPU核心数、批量大小和指定的最大工作线程数。
  6. 选择采样器

    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    

    如果rank为-1,说明不是分布式环境,则sampler为None;否则,使用DistributedSampler来处理在分布式训练下的数据采样。

  7. 初始化随机种子

    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)
    

    创建一个随机数生成器,并设置种子,以确保不同进程的随机性一致。

  8. 返回数据加载器

    return InfiniteDataLoader(dataset, batch_size=batch_size, shuffle=shuffle and sampler is None, num_workers=nw, sampler=sampler, pin_memory=PIN_MEMORY, worker_init_fn=seed_worker, generator=generator)
    

    初始化并返回一个InfiniteDataLoader,这是一个可以连续提供数据的加载器,适用于训练中的无穷循环。设置批量大小、打乱状态、工作线程数、采样器、内存固定、工作线程初始化函数和随机生成器。

该函数的主要功能是创建一个用于YOLOv5分类器的数据加载器,支持数据增强、缓存、分布式训练和多线程数据加载。通过合理配置函数参数,用户可以控制数据加载器的行为,以适应不同的训练需求。此数据加载器为训练过程提供高效的数据访问,确保训练效率最大化。

 四、

class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLOv5 Classification Dataset.
    Arguments
        root:  Dataset path
        transform:  torchvision transforms, used by default
        album_transform: Albumentations transforms, used if installed
    """

    def __init__(self, root, augment, imgsz, cache=False):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
        else:
            sample = self.torch_transforms(im)
        return sample, j

这段代码定义了一个名为 ClassificationDataset 的类,继承自 torchvision.datasets.ImageFolder,用于处理图像分类任务。在 YOLOv5 框架中,该类提供了一种灵活的方式来加载和预处理图像数据,包括数据增强和缓存功能。以下是对代码的逐步分解和详细解释:

3.1 类定义和文档字符串

class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
YOLOv5 Classification Dataset.
Arguments
    root:  Dataset path
    transform:  torchvision transforms, used by default
    album_transform: Albumentations transforms, used if installed
"""

这个类是 YOLOv5 分类数据集的实现,继承自 ImageFolder。文档字符串描述了类的功能和参数。

3.2 方法 __init__

    def __init__(self, root, augment, imgsz, cache=False):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im

这段代码的目的是初始化一个图像分类数据集实例。它设置了图像的根目录,定义了图像转换和数据增强配置,设置了缓存策略(将图像缓存在内存或磁盘),并准备了样本列表,以存储图像文件的路径及其对应的缓存路径。整体来看,这段代码为图像分类任务的数据加载与预处理提供了基础设施。

  • 这段代码是一个类的初始化方法(__init__),用来初始化一个图像分类数据集对象。以下是对代码逐步分解和详细解释:

  • 方法定义

    def __init__(self, root, augment, imgsz, cache=False):
    
    • __init__是一个特殊的方法,用于构造类的实例。
    • self:表示实例本身。
    • root:表示数据集的根目录(存放图像的文件夹)。
    • augment:表示是否应用数据增强。
    • imgsz:表示输入图像的尺寸。
    • cache:可选参数,用于指定是否缓存图像(默认为False)。
  • 调用父类初始化

    super().__init__(root=root)
    

    调用父类(在这里是torchvision.datasets.ImageFolder)的初始化方法,传递根目录root

  • 定义图像转换

    self.torch_transforms = classify_transforms(imgsz)
    

    classify_transforms(imgsz):调用一个函数生成图像变换,通常包括图像的缩放、裁剪、归一化等,并且指定输出的图像大小为imgsz

  • 条件应用数据增强

    self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
    
    • 如果augment为真,调用classify_albumentations生成增强变换,传递augmentimgsz
    • 如果augment为假,则self.album_transformsNone
  • 设置缓存策略

    self.cache_ram = cache is True or cache == 'ram'
    self.cache_disk = cache == 'disk'
    
    • self.cache_ram:如果cacheTrue或为'ram',则将其设置为True,表示将图像缓存到内存中。
    • self.cache_disk:如果cache'disk',则将其设置为True,表示将图像缓存到磁盘中。
  • 准备样本列表

    self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]
    
    • 这段代码是对self.samples进行处理,将每个样本的元组转换为列表,并为每个样本添加两个新元素:
      • Path(x[0]).with_suffix('.npy'):将样本的文件路径转换为.npy文件路径,方便后续可能的图像缓存。
      • None:占位符,可能用于存储图像的内部表示。

3.3 方法__getitem__

    def __getitem__(self, i):
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
        else:
            sample = self.torch_transforms(im)
        return sample, j

这段代码是一个类中的 __getitem__ 方法,目的是根据索引 i 获取数据集中的一个样本。以下是对这段代码的逐步分解与详细解释:

  1. 参数

    • i:传入的索引,用于访问数据集中的某一项。
  2. 获取样本

    f, j, fn, im = self.samples[i]
    

    这一行从 self.samples 中获取第 i 个样本。这个样本是一个包含多个元素的元组,通常包括:

    • f:文件名(图像文件的路径)。
    • j:索引(可能用于表示类别或标签)。
    • fn:带有 .npy 后缀的文件名(用于缓存图像的 numpy 数组)。
    • im:图像数据信息(可能为 None,表示需要读取文件)。
  3. 图像缓存机制

    • 内存缓存
    if self.cache_ram and im is None:
        im = self.samples[i][3] = cv2.imread(f)
    

    如果启用了内存缓存(self.cache_ram 为真),并且 im 是 None,则使用 OpenCV 读取图像文件(cv2.imread(f)),并将其存储到 self.samples[i][3] 中,以便后续访问。

    • 磁盘缓存
    elif self.cache_disk:
        if not fn.exists():  # load npy
            np.save(fn.as_posix(), cv2.imread(f))
        im = np.load(fn)
    

    如果启用了磁盘缓存(self.cache_disk 为真),首先检查缓存文件(.npy 文件)是否存在。如果不存在,则将读取的图像数据保存为 .npy 格式的文件。然后,无论如何,都加载这个 .npy 文件(np.load(fn))赋值给 im

    • 直接读取图像
    else:  # read image
        im = cv2.imread(f)  # BGR
    

    如果以上两种缓存机制都不启用,则直接读取图像文件并将其赋值给 im

  4. 图像转换

    if self.album_transforms:
        sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
    else:
        sample = self.torch_transforms(im)
    

    最后,根据是否使用 album_transforms 进行图像的变换处理:

    • 如果启用了阿尔布门图像转换 (album_transforms),则首先将 BGR 格式的图像转换为 RGB 格式,并应用变换。
    • 否则,使用 PyTorch 的转换函数进行处理。
  5. 返回结果

    return sample, j
    

    返回处理后的图像(sample)和索引或标签(j)。

这个 __getitem__ 方法实现了数据集的索引访问功能,从给定的样本列表中读取指定的图像,并能够根据需求从内存或磁盘中缓存图像数据。整个过程支持图像的读入、缓存机制(内存及磁盘)以及图像变换,旨在高效获取训练所需的图像数据。最终,它返回经过必要处理的图像和对应的索引,供后续模型训练使用。

3.4 总结

ClassificationDataset 类实现了一个用于图像分类的自定义数据集,提供了灵活的图像加载和预处理功能。它支持数据增强,能够根据用户指定的选项将数据缓存到 RAM 或磁盘中,从而提高数据加载的效率。这个类适用于需要进行图像分类的场景,并与 YOLOv5 框架兼容,有助于简化训练过程中的数据准备工作。

四、自定义DataLoader 

class InfiniteDataLoader(dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler


    def __iter__(self):
        while True:
            yield from iter(self.sampler)

这段代码定义了两个类:InfiniteDataLoader 和 _RepeatSampler它们的目的是创建一个可以无限迭代的数据加载器,允许在训练过程中重复使用数据,进行持续性采样,而不需要在每个epoch开始时重新加载数据。

4.1. InfiniteDataLoader 类

  • 继承自 dataloader.DataLoader

    • InfiniteDataLoader 是 DataLoader 的子类,旨在扩展其功能,使得数据加载器可以无限次迭代。
  • 构造函数 __init__

    • super().__init__(*args, **kwargs):调用父类的构造函数进行初始化,以确保 DataLoader 的所有功能均被保留。
    • object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)):将 self.batch_sampler 属性设置为 _RepeatSampler 类的实例。这允许使用 _RepeatSampler 来创建无限迭代的逻辑。
    • self.iterator = super().__iter__():创建一个迭代器,调用父类的迭代方法。
  • __len__ 方法

    • 返回当前批次采样器的长度 (len(self.batch_sampler.sampler))。这表示可以迭代的批次数量。
  • __iter__ 方法

    • 使用一个循环来无限地返回迭代器的下一个值,通过调用 next(self.iterator) 以持续提供数据。

4.2. _RepeatSampler 类

  • 构造函数 __init__

    • 接收一个 sampler 参数,并将其存储为实例属性 self.sampler。这个 sampler 负责提供数据样本。
  • __iter__ 方法

    • 实现了一个无限循环一直返回 self.sampler 中的元素。这意味着它会不断重复迭代 self.sampler,提供样本,而不会停止。

这段代码实现了一个 InfiniteDataLoader,它基于 PyTorch 的 DataLoader 通过重新使用工作进程,实现了一个无限循环的数据加载机制使用 _RepeatSampler,在每个epoch中可以持续不断地从数据集中采样,而不必在每个epoch开始时重新加载数据。这主要用于需要对数据进行多次迭代或训练时,避免了频繁的数据加载,能够提高训练效率。

五、LoadImages

class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        # Rotate a cv2 video manually
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        return self.nf  # number of files

5.1 类LoadImages 定义

类 LoadImages 该类是 YOLOv5 中的图像/视频数据加载器,主要负责从指定路径加载图像或视频文件,并进行预处理以供训练或推理使用。

5.2 __init__

    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

下面对给定的代码逐步分解并详细解释:

  • 初始化方法定义

    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
    

    这是一个类的初始化方法,用于创建一个新实例。它接受多个参数:

    • path: 数据路径,可以是图像或视频文件的路径。
    • img_size: 设置图像或视频的大小,默认值为640。
    • stride: 步幅,默认值为32。
    • auto: 一个布尔值,表示是否自动调整图像大小,默认值为True。
    • transforms: 可选的图像变换。
    • vid_stride: 视频帧间隔,默认值为1。
  • 文件列表初始化

    files = []
    
  • 路径处理

    for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
        p = str(Path(p).resolve())
    
    • 判断path是否为列表或元组,如果是,则对其中的路径进行排序;否则,将其转为列表并处理。
    • 使用Path(p).resolve()将路径转换为绝对路径。
  • 文件查找

    if '*' in p:
        files.extend(sorted(glob.glob(p, recursive=True)))  # glob
    elif os.path.isdir(p):
        files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
    elif os.path.isfile(p):
        files.append(p)  # files
    else:
        raise FileNotFoundError(f'{p} does not exist')
    
    • 如果路径中包含'*',使用glob模块搜索匹配的文件。
    • 如果路径是一个目录,获取该目录下所有文件。
    • 如果路径是一个文件,将该文件添加到列表。
    • 如果路径无效,抛出相应的错误。
  • 分类文件

    images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
    videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
    
    • files中筛选出图像文件和视频文件。
    • IMG_FORMATSVID_FORMATS 应该是预定义的图像和视频格式集合。
  • 文件数量及标志设置

    ni, nv = len(images), len(videos)
    
    self.img_size = img_size
    self.stride = stride
    self.files = images + videos
    self.nf = ni + nv  # number of files
    self.video_flag = [False] * ni + [True] * nv
    self.mode = 'image'
    self.auto = auto
    self.transforms = transforms  # optional
    self.vid_stride = vid_stride  # video frame-rate stride
    
    • 计算图像和视频的数量。
    • 将一些重要的参数和属性初始化存储在self对象中,以供类的其他方法使用。
  • 视频初始化

    if any(videos):
        self._new_video(videos[0])  # new video
    else:
        self.cap = None
    
    • 检查是否有视频文件,如果有,调用_new_video方法初始化第一个视频。
    • 如果没有视频,将self.cap设置为None
  • 异常处理

    assert self.nf > 0, f'No images or videos found in {p}. ' \
                        f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
    
    • 确保至少找到一个图像或视频文件,如果没有找到,抛出断言错误并返回提示。

这段代码的主要功能是初始化一个数据加载器,用于加载图像和视频文件。它的关键步骤包括:

  • 接收文件路径并解析这些路径中的每一个文件。
  • 分类和过滤出有效的图像和视频文件。
  • 计算总的文件数量并初始化一些属性。
  • 检查是否找到文件,以便准备后续的数据处理工作。

4.3 __iter__(self)

  • 功能
    • 初始化计数器并返回自身,允许该类的实例作为迭代器使用。
    def __iter__(self):
        self.count = 0
        return self

这段代码是一个类的方法,通常用于实现迭代器协议。下面是逐步分解和详细解释:

  • def __iter__(self)::

    • 这是一个类中的实例方法,名为 __iter__。在Python中,任何类只要实现了这个方法,就可以被迭代(例如,可以使用for循环进行遍历)。
  • self.count = 0:

    • 这里,我们给实例变量 count 赋值为0。这个变量通常用来跟踪当前迭代到的位置(或索引)。在每次迭代开始时,将其重置为0,确保从头开始迭代。
  • return self:

    • 将当前对象 self 返回。这使得对象可以作为迭代器直接使用。根据Python的迭代协议,返回一个实现了 __next__ 方法的对象也是有效的。

这段代码定义了一个迭代器的开头部分。具体来说,它实现了迭代器协议中的 __iter__ 方法,使得类的实例可以被迭代。每次开始迭代时,count 计数器都会被重置为0,确保迭代从头开始。返回 self 意味着这个类的实例本身将被视为一个迭代器,能够与Python的循环结构(如for循环)直接兼容。

主要功能

  • 迭代器接口: 这段代码是迭代器协议的一部分,允许对象支持迭代功能,能够被for循环等语法迭代访问。

  • 重置计数器: 每当迭代开始时,count计数器被重置,确保每次迭代都从最初的状态开始。

5.3 __next__(self)

  • 功能
    • 加载下一个图像或视频帧:
      • 如果当前路径是视频,使用 cv2.VideoCapture 读取视频帧。
      • 如果是图像,使用 OpenCV 的 cv2.imread()读取图像文件。
      • 进行必要的预处理,返回路径、处理过的图像、小图像和其他信息。
    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

这段代码定义了一个名为 __next__ 的方法,通常用于实现迭代器,使得对象可以使用 next() 函数获取下一个元素。以下是对该代码的逐步分解和详细解释:

  1. 判断是否结束迭代:

    if self.count == self.nf:
        raise StopIteration
    

    此处检查 self.count 是否等于 self.nf(文件总数)。如果相等,表明已经读取完所有的文件,抛出 StopIteration 异常以结束迭代。

  2. 获取当前文件路径:

    path = self.files[self.count]
    

    根据当前计数 self.count 从 self.files 列表中获取当前文件的路径。

  3. 检查是否为视频文件:

    if self.video_flag[self.count]:
    

    通过 self.video_flag 列表检查当前文件是否为视频文件。

  4. 读取视频文件:

    • 如果是视频,首先更新模式为 'video':
      self.mode = 'video'
      
    • 进行抓取视频帧的操作以跳过指定帧 (self.vid_stride):
      for _ in range(self.vid_stride):
          self.cap.grab()
      
    • 从视频捕获对象 self.cap 中检索一帧图像:
      ret_val, im0 = self.cap.retrieve()
      
    • 如果未成功获取帧,进入循环寻找下一个视频:
      while not ret_val:
          self.count += 1
          self.cap.release()
          if self.count == self.nf:  # last video
              raise StopIteration
          path = self.files[self.count]
          self._new_video(path)
          ret_val, im0 = self.cap.read()
      
    • 更新帧计数,处理成功的帧并生成状态信息字符串:
      self.frame += 1
      s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
      
  5. 读取图像文件:

    • 如果当前文件不是视频,增加计数后读取图像:
      self.count += 1
      im0 = cv2.imread(path)  # BGR
      assert im0 is not None, f'Image Not Found {path}'
      s = f'image {self.count}/{self.nf} {path}: '
      
    • 确保图像成功读取,未找到则抛出异常。
  6. 应用变换:

    • 如果定义了图像变换,则应用变换:
      if self.transforms:
          im = self.transforms(im0)  # transforms
      
    • 否则对图像进行填充和调整大小,以及通道转换:
      else:
          im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
          im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
          im = np.ascontiguousarray(im)  # contiguous
      
  7. 返回值:

    return path, im, im0, self.cap, s
    

    最后,返回文件路径、预处理后的图像、原始图像、视频捕播放器的引用和状态信息字符串。

该 __next__ 方法主要功能是实现一个迭代器,用于读取图像和视频文件。它能够判断当前文件类型,根据类型采取不同的读取方式,并且在读取完成后会自动处理图像的尺寸和格式。此方法确保了可以顺序访问数据集中的所有文件,适用于训练或推断过程中需要逐帧处理图像和视频的场景。通过抛出 StopIteration,可以优雅地结束迭代,符合 Python 迭代器协议。

5.4 _new_video(self, path)

  • 功能
    • 使用给定的视频路径创建一个新的视频捕捉对象,并获取视频帧数和方向。
    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

下面是对提供的代码的逐步分解和详细解释:

def _new_video(self, path):
  • 定义了一个名为 _new_video 的方法,该方法属于一个类(self 指代类的实例)。
  • 这个方法接收一个参数 path,表示视频文件的路径。
    # Create a new video capture object
    self.frame = 0
  • 该行注释说明这个方法的功能是创建一个新的视频捕获对象。
  • 同时,self.frame 被初始化为 0,用于跟踪当前处理的视频帧数。
    self.cap = cv2.VideoCapture(path)

  • 使用 OpenCV 的 VideoCapture 类创建一个视频捕获对象,将其赋值给 self.cap,从而能够读取指定路径 path 中的视频文件。
    self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

  • 这行代码使用 get 方法来获取视频的总帧数,具体属性 cv2.CAP_PROP_FRAME_COUNT 表示视频总帧数。
  • 通过将总帧数除以 self.vid_stride(视频帧率步幅)来计算有效的帧数,并将其赋值给 self.frames
    self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees

  • 获取视频的方向属性(例如旋转角度),cv2.CAP_PROP_ORIENTATION_META 属性用于获取视频的方向信息。
  • 将获取的方向转换为整数并保存到 self.orientation 中。
    # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

  • 这一行是注释,表明如果需要,可以通过设置 cv2.CAP_PROP_ORIENTATION_AUTO 属性来禁用自动方向调整。这是在某些情况下可能导致的问题,用于参考。

该代码定义了一个方法 _new_video其主要功能是创建一个新的视频捕获对象,并初始化与该视频相关的几个重要属性。具体来说,该方法:

  • 初始化视频帧计数器。
  • 使用给定路径创建视频捕获对象。
  • 计算并保存视频的总帧数(除以步幅)以及视频的方向信息。

这个方法主要用于视频处理的上下文中,使得程序能够有效地读取和处理视频内容。

5.5 _cv2_rotate(self, im)

该方法 _cv2_rotate 的主要功能是根据图像的方向信息(由 self.orientation 提供),对输入的图像进行相应的旋转处理。通过判断方向的不同,该方法执行不同类型的旋转操作,并返回处理后的图像。这在处理来自相机或视频的帧时特别有用,因为不同的设备可能会以不同的方向保存图像数据。

  • 功能
    • 根据视频的方向信息手动调整图像的方向。
  •    def _cv2_rotate(self, im):
            # Rotate a cv2 video manually
            if self.orientation == 0:
                return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
            elif self.orientation == 180:
                return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
            elif self.orientation == 90:
                return cv2.rotate(im, cv2.ROTATE_180)
            return im

    这段代码定义了一个名为 _cv2_rotate 的方法,该方法主要用于根据图像的方向信息,对输入的图像进行旋转处理。以下是对代码的逐步分解和详细解释:

  • 方法定义

    def _cv2_rotate(self, im):
    

    该行定义了一个名为 _cv2_rotate 的方法,接受两个参数:

    • self:指向当前类实例的引用。
    • im:要进行旋转处理的图像(使用 OpenCV 读取的图像)。
  • 注释

    # Rotate a cv2 video manually
    

    这行注释表明该方法的目的是手动旋转一个使用 OpenCV 处理的视频帧或图像。

  • 判断图像方向

    if self.orientation == 0:
    

    代码检查实例属性 self.orientation 的值,根据其值决定如何旋转图像。

  • 旋转逻辑

    • 顺时针旋转90度

      return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
      

      如果方向为0,调用 cv2.rotate() 方法将图像顺时针旋转90度并返回。

    • 逆时针旋转90度

      elif self.orientation == 180:
          return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
      

      如果方向为180,调用 cv2.rotate() 将图像逆时针旋转90度并返回。

    • 旋转180度

      elif self.orientation == 90:
          return cv2.rotate(im, cv2.ROTATE_180)
      

      如果方向为90,调用 cv2.rotate() 将图像旋转180度并返回。

  • 返回原图像

    return im
    

    如果上述条件都不满足(即方向不是0、180或90),则直接返回原图像 im,表示不做任何旋转。

5.6 __len__(self)

  • 功能
    • 返回加载的图像或视频的总数量。
    def __len__(self):
        return self.nf  # number of files

代码片段中的 __len__ 方法是一个特殊方法,在Python中用于定义一个对象的长度或大小。这段代码的功能如下:

  1. 方法定义

    • def __len__(self)::这是一个类的方法,self 参数代表类的实例。在Python中,任何具有 __len__ 方法的对象都可以通过 len() 函数获得其长度。
  2. 返回值

    • return self.nfnf 是一个属性,通常在类的构造函数(__init__ 方法)中被定义和初始化。它表示一个特定的数值,通常是文件的数量。在这个上下文中,nf 表示当前对象中处理的文件总数。

此方法主要用于让类的实例支持通过 len() 函数来获取该对象中包含的文件数量。这是在处理与数据集(如图像或视频文件)相关的任务时非常有用的。当你调用 len(instance) 时,这个方法会被自动调用,返回该实例中包含的文件总数,为数据处理和管理提供方便。

5.7 总结

LoadImages 类的主要功能是从指定路径加载图像和视频文件,并进行预处理,以便在 YOLOv5 模型中使用。它能够处理多种输入格式,支持图像和视频的混合加载,自动判断所需的预处理操作,确保数据在训练或推理过程中的一致性和效率。此类实现了 Python 的迭代器协议,允许用户以迭代的方式访问和处理图像数据。

六、LoadStreams

class LoadStreams:
    # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
    def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride
        self.vid_stride = vid_stride  # video frame-rate stride
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
                assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'{st}Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        self.auto = auto and self.rect
        self.transforms = transforms  # optional
        if not self.rect:
            LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap, stream):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]  # frame number, frame array
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
                    self.imgs[i] = np.zeros_like(self.imgs[i])
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        im0 = self.imgs.copy()
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous

        return self.sources, im, im0, None, ''

    def __len__(self):
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years

6.1 类定义和初始化 (__init__):

  • 类 LoadStreams: 该类用于从各种视频流(如 RTSP、RTMP、HTTP 等)中加载图像,广泛用于 YOLOv5 的实时检测。
  • 初始化参数:
    • sources: 视频流源,可以是文件名,或是包含多个源的文本文件。
    • img_sizestride: 图像大小和步幅,用于图像预处理。
    • auto: 自动调整标志。
    • transforms: 应用到图像的变换(如数据增强)。
    • vid_stride: 从视频帧中跳过的帧数。
  • 优化: 设置 cudnn.benchmark 为 True,优化固定尺寸输入的推断速度。
  • 源处理: 如果 sources 是文件,则从中读取源名称,否则将其直接用作源名称。随后清洗每个源名称。
  • 状态初始化: 初始化图像、帧速率和视频帧数的数组,以及用于读取视频流的线程。

6.2 循环处理每个源:

  • 遍历每个视频源,并为每个源创建一个线程以读取视频流。
  • YouTube 检查: 如果源是 YouTube 视频,则使用 pafy 库获取最佳视频流 URL。
  • 视频捕获: 使用 cv2.VideoCapture 打开视频流,并保存视频的宽度、高度和帧率。
  • 帧读取: 确保读取第一帧并启动线程以持续读取帧。

6.3 更新函数 (update):

  • 该函数在单独的线程中运行,负责从视频流中抓取帧并更新图像列表。
  • 使用 grab() 和 retrieve() 方法获取帧数据,如果在某个间隔未能成功捕获图像,则记录警告。
  • 如果视频流变得不可用,会尝试重新打开它。

6.4 迭代器 (__iter__ 和 __next__):

  • 遍历每个流的数据。
  • __next__方法会在所有帧流活跃且未按下退出键时返回下一帧。如果线程关闭或用户请求退出,则停止迭代。
  • 处理和返回图像,通过调整大小和变换来准备数据。

6.5 图像和数据处理:

  • 使用可选的图像变换来处理读取到的图像,如格式转换和大小调整,确保在形状不同时加入警告。

6.6 长度 (__len__):

  • 允许计算流的数量,返回当前加载的视频流的数量。

6.7 总结

LoadStreams 类旨在从不同的视频源加载实时数据流(如 RTSP 和 HTTP),并通过多线程处理抓取和预处理这些视频流。它的主要功能包括:

  • 多源支持: 能够支持多个视频流源,自动检测视频流类型并处理。
  • 帧处理: 实现了从视频流中抓取和处理帧的多线程机制,可以实时更新图像数据。
  • 形状处理与变换: 在图像提取后,进行尺寸、格式转换和数据增强,为后续的深度学习模型提供优化的数据输入。

通过这种设计,用户可以轻松地将视频流整合到 YOLOv5 的检测管道中,增强了模型在实际应用中的实时处理能力。

七、通用函数

7.1 函数 img2label_paths

def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]

这段代码定义了一个名为 img2label_paths 的函数,其主要功能是根据给定的图像路径列表生成相应的标签文件路径列表。在计算机视觉项目中,图像通常会有相应的标签文件来描述图像中的对象和它们的位置,特别是在使用如 YOLO 等目标检测算法时。

  1. 函数定义

    def img2label_paths(img_paths):
    

    这里定义一个名为 img2label_paths 的函数,参数 img_paths 是一个包含图像路径的列表。

  2. 定义路径分隔符

    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'
    
    • sa 和 sb 是两个字符串,分别表示 /images/ 和 /labels/ 的路径格式。os.sep 是操作系统特定的路径分隔符(在Windows上是 \,在Linux和Mac上是 /)。
  3. 生成标签路径

    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
    
    • 这是一个列表推导式,它遍历 img_paths 中的每个图像路径 x
    • 对于每个路径 x
      • x.rsplit(sa, 1):从右侧分割字符串,只分割一次,得到一个列表,其中第一个元素是切分后的主路径部分,第二个元素是 images 部分。
      • sb.join(...):然后将主路径部分与 sb(即 /labels/)结合,形成新的路径。
      • rsplit('.', 1)[0]:再从结果字符串右侧分割一次,去掉文件扩展名(如 .jpg.png等)。
      • 最后,加上 '.txt' 后缀,从而形成对应的标签文件路径。:

img2label_paths 函数的主要功能是为一组图像路径生成对应的标签路径。其具体操作是通过替换路径中的 /images/ 部分为 /labels/,并保证标签文件以 .txt 为后缀。这在处理数据集时非常有用,尤其是在需要将图像与其标签相匹配的时候,常用于机器学习和深度学习项目中的数据预处理阶段。

 7.2 verify_image_label

这个函数用于检查每一张图片和每一张label文件是否完好。

图片文件: 检查内容、格式、大小、完整性

label文件: 检查每个gt必须是矩形(每行都得是5个数 class+xywh) + 标签是否全部>=0 + 标签坐标xywh是否归一化 + 标签中是否有重复的坐标

def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                if f.read() != b'\xff\xd9':  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 5), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]

这段代码定义了一个名为verify_image_label的函数,用于验证图像和其对应标签的有效性。以下是对代码的逐步分解和详细解释:

  1. 函数定义及参数

    def verify_image_label(args):
    
    • 函数接收一个参数args,它是包含三个元素的元组:图像文件路径im_file、标签文件路径lb_file以及用于消息显示的前缀prefix
  2. 初始化变量

    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []
    
    • nm:缺失的标签数目。
    • nf:找到的标签数目。
    • ne:空标签数量。
    • nc:损坏的标签数量。
    • msg:用于存储警告消息的字符串。
    • segments:存储图像分段信息的列表。
  3. 图像验证

    im = Image.open(im_file)
    im.verify()  # PIL verify
    
    • 加载指定的图像文件并使用PIL库验证图像的完整性。
  4. 获取图像尺寸

    shape = exif_size(im)
    assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
    
    • 获取图像的尺寸并确保其大于10像素,如果不符合,则触发一个断言异常。
  5. 格式检查

    assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
    
    • 确认图像的格式是否在允许的格式中。
  6. JPEG特有检查(如果格式为JPEG):

    if im.format.lower() in ('jpg', 'jpeg'):
        with open(im_file, 'rb') as f:
            f.seek(-2, 2)
            if f.read() != b'\xff\xd9':  # corrupt JPEG
                ...
    
    • 检查JPEG文件的结尾标志,若缺失,则认为该JPEG文件可能受损并保存修复后的图像。
  7. 标签验证

    if os.path.isfile(lb_file):
    
    • 检查标签文件是否存在,如果存在则读取并解析标签内容。
  8. 处理标签内容

    • 将标签内容分拆为类和坐标,并对其进行验证,包括:
      • 确认标签的列数为5(class x, y, width, height)。
      • 确保所有标签的值均为非负并且在适当范围内。
      • 检查是否有重复的标签,若有则移除。
  9. 返回值

    return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    
    • 返回一个包含图像文件名、标签数组、图像尺寸、图像分段、缺失的标签数、找到的标签数、空标签数、损坏标签数和警告信息的元组。
  10. 异常处理

    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]
    
    • 捕获异常,增加损坏计数,显示相应警告信息,并返回相关数据。

verify_image_label函数的主要功能是验证给定图像及其对应标签文件的有效性。它通过检查图像的读取、格式、尺寸合法性以及标签的完整性与合理性来确保数据的有效性。若发现问题,该函数会返回详细的错误信息和统计数据,便于后续处理和调试。总之,此函数在数据准备和清洗过程中具有重要作用,有助于提升训练模型的数据质量。

7.3 flatten_recursive

这个模块是将一个文件路径中的所有文件复制到另一个文件夹中 即将image文件和label文件放到一个新文件夹中。 

def flatten_recursive(path=DATASETS_DIR / 'coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(f'{str(path)}_flat')
    if os.path.exists(new_path):
        shutil.rmtree(new_path)  # delete output folder
    os.makedirs(new_path)  # make new output folder
    for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)

这段代码的主要功能是将递归目录中的所有文件扁平化,即将它们移动到同一层级的目标文件夹中。下面逐步分解并详细解释代码:

  1. 函数定义与参数:

    def flatten_recursive(path=DATASETS_DIR / 'coco128'):
    

    定义了一个名为 flatten_recursive 的函数,默认参数为 DATASETS_DIR / 'coco128'。该参数表示要处理的文件夹路径。

  2. 创建新路径:

    new_path = Path(f'{str(path)}_flat')
    

    创建一个新的路径 new_path,该路径是在原路径的基础上添加了 _flat 后缀的一个目录。这个目录将用于存放扁平化后的文件。

  3. 检查新路径是否存在:

    if os.path.exists(new_path):
        shutil.rmtree(new_path)  # delete output folder
    

    使用 os.path.exists() 检查 new_path 是否存在。如果存在,调用 shutil.rmtree() 删除该目录(及其内容),确保每次运行时都从一个干净的状态开始。

  4. 创建新目录:

    os.makedirs(new_path)  # make new output folder
    

    使用 os.makedirs() 创建新的输出目录 new_path

  5. 遍历递归文件夹:

    for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
    
    • 使用 glob.glob() 函数查找 path 目录及其子目录下的所有文件。'**/*.*' 表示匹配所有文件(具有扩展名的文件)。recursive=True 选项允许在所有子目录中查找文件。
    • tqdm() 包装了 glob.glob() 的结果,以便在遍历时显示进度条。
  6. 复制文件:

    shutil.copyfile(file, new_path / Path(file).name)
    

    对于找到的每个文件,使用 shutil.copyfile() 将其复制到 new_path,同时仅使用文件名而不再保留原来的路径。

该函数 flatten_recursive 的主要功能是将指定目录及其子目录中的所有文件提取并复制到一个新的扁平化目录中。通过删除现有的输出目录并在开始时创建一个新的输出目录,确保每次运行时都从头开始,避免旧数据的干扰。使用 glob 模块递归查找文件和 shutil 模块处理文件的复制操作,同时结合 tqdm 显示进度,让用户可以直观地看到处理进度。

7.4 extract_boxes

这个模块是将目标检测数据集转化为分类数据集 ,集体做法: 把目标检测数据集中的每一个gt拆解开 分类别存储到对应的文件当中。

def extract_boxes(path=DATASETS_DIR / 'coco128'):  # from utils.dataloaders import *; extract_boxes()
    # Convert detection dataset into classification dataset, with one directory per class
    path = Path(path)  # images dir
    shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in IMG_FORMATS:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file) as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'

以下是对代码的逐步分解和详细解释:

  1. 函数定义

    def extract_boxes(path=DATASETS_DIR / 'coco128'):
    

    该函数定义了一个名为 extract_boxes 的函数,参数 path 默认指向一个数据集目录(DATASETS_DIR/coco128)。

  2. 路径初始化

    path = Path(path)  # images dir
    

    使用 Path 将字符串路径转换为 Path 对象,以便更方便地进行文件操作。

  3. 删除旧的分类文件夹

    shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None
    

    检查是否存在名为 classification 的目录,如果存在,则删除该目录。这是为了确保提取的箱子不会和之前的结果冲突。

  4. 获取所有文件列表

    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    

    使用 rglob 获取目录下所有文件(包括子目录),并将它们存储在 files 列表中,同时计算文件数量 n

  5. 遍历文件

    for im_file in tqdm(files, total=n):
    

    使用 tqdm 包装文件列表以便显示进度条,逐个处理每个文件。

  6. 检查文件格式

    if im_file.suffix[1:] in IMG_FORMATS:
    

    检查文件后缀是否在预定义的图像格式列表 IMG_FORMATS 中。

  7. 读取图像

    im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
    

    使用 OpenCV 读取图像,并将其从 BGR 转换为 RGB 格式。

  8. 获取图像尺寸

    h, w = im.shape[:2]
    

    获取图像的高度 h 和宽度 w

  9. 获取标签文件

    lb_file = Path(img2label_paths([str(im_file)])[0])
    

    基于图像文件路径生成相应的标签文件路径。

  10. 检查标签文件存在性

    if Path(lb_file).exists():
    

    检查标签文件是否存在。

  11. 读取标签内容

    with open(lb_file) as f:
        lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels
    

    打开标签文件,并读取内容,将每一行拆分为列表,最终转换为 NumPy 数组 lb

  12. 处理每个标签

    for j, x in enumerate(lb):
    

    遍历标签数组中的每个标签。

  13. 创建新文件路径

    c = int(x[0])  # class
    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
    

    根据标签中的分类信息 c 创建新的文件夹和文件名,以便按类分类保存提取的图像。

  14. 确保文件夹存在

    if not f.parent.is_dir():
        f.parent.mkdir(parents=True)
    

    如果目标目录不存在,则创建目录。

  15. 计算边界框

    b = x[1:] * [w, h, w, h]  # box
    b[2:] = b[2:] * 1.2 + 3  # pad
    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)
    

    从标签中获取边界框坐标,将其从相对位置转换为绝对像素值,并进行一定的扩展(padding)。

  16. 边界框裁剪

    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
    

    保证边界框的坐标不超出图像边界。

  17. 保存裁剪图像

    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
    

    使用 OpenCV 将裁剪后的图像保存到新的文件路径,确保保存成功。

此代码的主要功能是将目标检测数据集中的图像和相应的标签转换为分类数据集。具体步骤包括删除旧的分类目录、读取所有图像文件、获取标签、根据标签中的框信息提取相应的图像并分类存储。最终,每个物体的图像会被裁剪并保存到与其类别对应的子目录下。这种转换使得数据能够应用于分类任务,而不仅仅是检测任务。

7.5 autosplit

这个模块是进行自动划分数据集。当使用自己数据集时,可以用这个模块进行自行划分数据集。

def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.dataloaders import *; autosplit()
    Arguments
        path:            Path to images directory
        weights:         Train, val, test weights (list, tuple)
        annotated_only:  Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], 'a') as f:
                f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file

这段代码定义了一个名为 autosplit 的函数,其功能是将数据集中的图像自动拆分为训练集、验证集和测试集,并将结果保存为文本文件(autosplit_train.txtautosplit_val.txt 和 autosplit_test.txt)。下面是对代码逐步分解和详细解释:

  1. 函数定义与文档字符串:

    def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
    
    • 定义名为 autosplit 的函数。
    • 参数说明:
      • path: 图像文件所在的目录,默认是 DATASETS_DIR 下的 coco128/images
      • weights: 用于划分训练集、验证集和测试集的权重,默认是 0.9(训练集)、0.1(验证集)、0.0(测试集)。
      • annotated_only: 如果为真,只使用有标注的图像。
  2. 路径和文件收集:

    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    
    • 将 path 转换为 Path 对象,方便进行路径操作。
    • 使用 rglob 方法遍历 path 目录下所有文件,并过滤出图像文件(判断文件后缀是否在 IMG_FORMATS 中)。然后对文件列表进行排序。
    • n 保存文件数。
  3. 随机数种子设定:

    random.seed(0)  # for reproducibility
    
    • 设置随机数种子为 0,以确保程序的可重复性。
  4. 图像索引分配:

    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split
    
    • 随机从 [0, 1, 2] 中选择 n 个索引,使用给定的权重来区分训练、验证和测试集。
  5. 创建文本文件:

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing
    
    • 定义要创建的三个文本文件的名称。
    • 如果文件已存在,则删除它们,确保新的文本文件不会与旧的文件冲突。
  6. 图像划分与写入文件:

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], 'a') as f:
                f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file
    
    • 打印当前处理的路径,用于用户反馈。
    • 使用 tqdm 方法显示进度条,遍历图像列表。
    • 在每次循环中,根据索引确定图像所属的类别。
    • 检查是否只使用有标注的图像,如果是,确保图像对应的标签文件存在。
    • 将图像相对路径写入相应的文本文件。

该函数的主要功能是自动将指定路径下的图像数据集划分为训练集、验证集和测试集,并生成相应的文本文件以供后续使用。通过设置权重,可以灵活调整每个数据集的大小。此外,用户可以选择只使用有标注的图像,确保数据质量。这种数据分割方式在机器学习和深度学习的模型训练中非常常见,能够有效帮助管理和组织数据集。

7.6 函数get_hash

def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.md5(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
  1. 函数定义:

    def get_hash(paths):
    

    这里定义了一个名为 get_hash 的函数,它接受一个参数 paths。这个参数预期是一个包含文件或目录路径的列表。

  2. 注释说明:

    # Returns a single hash value of a list of paths (files or dirs)
    

    注释说明这个函数的目的,即返回一个由路径生成的单一哈希值,路径可以是文件或目录。

  3. 计算路径大小:

    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    

    这行代码使用生成器表达式来计算所有传入路径的总大小。os.path.getsize(p) 返回路径 p 所对应文件的大小(字节数),而 os.path.exists(p) 用来检查路径是否存在。只有当路径存在时,才会计算其大小。最终使用 sum() 函数将这些大小加起来。

  4. 创建 MD5 哈希对象:

    h = hashlib.md5(str(size).encode())  # hash sizes
    

    这里使用 hashlib 库创建一个 MD5 哈希对象 h。首先,将前面计算的 size 转换为字符串,并编码为字节字符串(使用 encode()),然后将这个字节字符串传递给 md5() 方法以初始化哈希。

  5. 更新哈希值:

    h.update(''.join(paths).encode())  # hash paths
    

    这行代码将所有路径连接为一个单一的字符串,并将其编码为字节字符串。接着,使用 update() 方法将这个字符串的哈希值添加到原有的哈希值中。这样,最终的哈希值将不仅包含文件大小的信息,还包含路径本身的信息。

  6. 返回哈希值:

    return h.hexdigest()  # return hash
    

    最后,使用 hexdigest() 方法将哈希对象 h 转换为十六进制字符串格式并返回。这就是最终生成的哈希值。

这段代码的主要功能是根据一组文件或目录的路径计算并返回一个唯一的 MD5 哈希值。这个哈希值由两个部分组成:一是所有路径所对应的文件的总大小,二是路径字符串本身。通过这种方式,该函数能够生成一个快照,反映出这些路径的状态,并且可以用于检测文件的一致性或变化。非常适合用来验证数据集的完整性或检查文件是否被修改。

八、依赖库

import contextlib
import glob
import hashlib
import json
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse

import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm

from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
                                 letterbox, mixup, random_perspective)
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements,
                           check_yaml, clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy,
                           xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first

# Parameters
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

下面将逐步分解并详细解释给出的代码,并总结其主要功能。

8.1 导入模块:

import contextlib
import glob
import hashlib
import json
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse

这部分代码导入了一系列用于处理文件、数据和并发的标准库模块。具体功能包括:

  • contextlib: 提供上下文管理功能。
  • glob: 用于查找符合特定模式的文件名。
  • hashlib: 提供安全哈希和消息摘要算法。
  • json: 处理JSON数据格式。
  • math: 提供数学函数。
  • os: 与操作系统交互。
  • random: 生成随机数。
  • shutil: 文件操作(如复制和删除文件)。
  • time: 时间处理。
  • Path: 对文件路径进行处理。
  • ThreadThreadPool: 用于多线程处理。

8.2 导入第三方库:

import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm
  • numpy: 数组和矩阵操作。
  • psutil: 进程和系统监控。
  • torch: PyTorch深度学习框架。
  • torchvision: 图像处理功能和数据集。
  • yaml: 处理YAML格式的数据。
  • PIL: 图像处理库,提供图像的打开、操作和保存功能。
  • tqdm: 显示进度条的库,方便用户观察进度。

8.3 导入自定义模块:

from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
                                 letterbox, mixup, random_perspective)
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements,
                           check_yaml, clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy,
                           xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first

这部分从自定义的工具库中导入功能,包括数据增强、通用的帮助函数和PyTorch相关的工具。

8.4 参数定义:

HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders

这部分定义了一些参数:

  • HELP_URL: 帮助文档的链接。
  • IMG_FORMATSVID_FORMATS: 支持的图像和视频文件后缀。
  • LOCAL_RANKRANK: 从环境变量中获取分布式训练的排名信息。
  • PIN_MEMORY: 指示数据加载器是否使用固定内存。

获取EXIF标签中方向信息:

for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

这段代码用于查找图像的EXIF数据中表示方向的标签,为后续处理图像时考虑旋转等问题做准备。

8.5 总结

这段代码主要功能是为YOLOv5(一个目标检测模型)提供数据加载和预处理的基础框架。它导入了所需的库、定义了参数并设置了图像和视频的文件格式,获取EXIF方向信息为后续处理做准备。可以认为这是一段为模型训练或推理设置数据管道的初始化代码。

参考:

  1. YOLOv5系列(二十三) 解析数据集处理部分dataloaders(详尽)
  2. 第十四篇—创建数据集(YOLOv5专题)​​​​​​​
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值