[MindSpore] Migrating from PyTorch to MindSpore: Data Processing (Hands-On)

Introduction

In the previous article we discussed data processing when migrating from PyTorch to MindSpore, covering built-in dataset loading and custom dataset loading, and comparing the two frameworks' approaches in detail.

In this article, we take the YOLOv5 source code as a case study and walk through how to migrate dataset loading and processing from PyTorch to MindSpore.

Code used in this article:
https://github.com/Linorman/mindyolo-transfer.git
https://github.com/Linorman/yolov5-transfer.git
Reference code:
https://github.com/ultralytics/yolov5.git
https://github.com/mindspore-lab/mindyolo.git

YOLOv5 Data Loading in PyTorch

Let's go to the YOLOv5 repository and see how the official source handles the dataset; you can clone it and browse locally.

Repository: https://github.com/ultralytics/yolov5.git

First, note that the official YOLOv5 training code supports several dataset formats. We take the COCO format as our example; the layout is as follows:

COCO_ROOT
    ├── train2017.txt
    ├── annotations
    │     └── instances_train2017.json
    ├── images
    │     └── train2017
    │             ├── 000000000001.jpg
    │             └── 000000000002.jpg
    └── labels
          └── train2017
                  ├── 000000000001.txt
                  └── 000000000002.txt

dataset_path (str): ./coco/train2017.txt
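
Each image in images/ has a same-named .txt file under labels/, with one object per line in normalized YOLO format: class x_center y_center width height. A minimal parsing sketch (the file path is illustrative):

import numpy as np

def read_yolo_labels(txt_path):
    # Each line: class x_center y_center width height, coordinates normalized to [0, 1]
    with open(txt_path) as f:
        rows = [line.split() for line in f.read().strip().splitlines()]
    return np.array(rows, dtype=np.float32)  # shape: (num_objects, 5)

labels = read_yolo_labels('./coco/labels/train2017/000000000001.txt')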

Next, open train.py and see how it loads the dataset:

# Trainloader
train_loader, dataset = create_dataloader(train_path,
                                          imgsz,
                                          batch_size // WORLD_SIZE,
                                          gs,
                                          single_cls,
                                          hyp=hyp,
                                          augment=True,
                                          cache=None if opt.cache == 'val' else opt.cache,
                                          rect=opt.rect,
                                          rank=LOCAL_RANK,
                                          workers=workers,
                                          image_weights=opt.image_weights,
                                          quad=opt.quad,
                                          prefix=colorstr('train: '),
                                          shuffle=True,
                                          seed=opt.seed)
labels = np.concatenate(dataset.labels, 0)
mlc = int(labels[:, 0].max())  # max label class
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

As the code shows, dataset loading is handled by the create_dataloader function; let's jump to its definition:

def create_dataloader(path,
                      imgsz,
                      batch_size,
                      stride,
                      single_cls=False,
                      hyp=None,
                      augment=False,
                      cache=False,
                      pad=0.0,
                      rect=False,
                      rank=-1,
                      workers=8,
                      image_weights=False,
                      quad=False,
                      prefix='',
                      shuffle=False,
                      seed=0):
    if rect and shuffle:
        LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + seed + RANK)
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=PIN_MEMORY,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
                  worker_init_fn=seed_worker,
                  generator=generator), dataset

This function covers both the dataset creation and the loader creation. Next, let's look at the implementation of the LoadImagesAndLabels class:

class LoadImagesAndLabels(Dataset):
    # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
    cache_version = 0.6  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 path,
                 img_size=640,
                 batch_size=16,
                 augment=False,
                 hyp=None,
                 rect=False,
                 image_weights=False,
                 cache_images=False,
                 single_cls=False,
                 stride=32,
                 pad=0.0,
                 min_items=0,
                 prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        self.albumentations = Albumentations(size=img_size) if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t]  # to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # to global path (pathlib)
                else:
                    raise FileNotFoundError(f'{prefix}{p} does not exist')
            self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert self.im_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e

        # Check cache
        self.label_files = img2label_paths(self.im_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except Exception:
            cache, exists = self.cache_labels(cache_path, prefix), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels, shapes, self.segments = zip(*cache.values())
        nl = len(np.concatenate(labels, 0))  # number of labels
        assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
        self.labels = list(labels)
        self.shapes = np.array(shapes)
        self.im_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update

        # Filter images
        if min_items:
            include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
            LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
            self.im_files = [self.im_files[i] for i in include]
            self.label_files = [self.label_files[i] for i in include]
            self.labels = [self.labels[i] for i in include]
            self.segments = [self.segments[i] for i in include]
            self.shapes = self.shapes[include]  # wh

        # Create indices
        n = len(self.shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Update labels
        include_class = []  # filter labels to include only these classes (optional)
        self.segments = list(self.segments)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
                    self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.im_files = [self.im_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.segments = [self.segments[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into RAM/disk for faster training
        if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
            cache_images = False
        self.ims = [None] * n
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache_images:
            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
            self.im_hw0, self.im_hw = [None] * n, [None] * n
            fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
            results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
            pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache_images == 'disk':
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
            pbar.close()

    def check_cache_ram(self, safety_margin=0.1, prefix=''):
        # Check image caching requirements vs available memory
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        n = min(self.n, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
            b += im.nbytes * ratio ** 2
        mem_required = b * self.n / n  # GB required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(f'{prefix}{mem_required / gb:.1f}GB RAM required, '
                        f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
        return cache

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f'{prefix}Scanning {path.parent / path.stem}...'
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                        desc=desc,
                        total=len(self.im_files),
                        bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'

        pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x

    def __len__(self):
        return len(self.im_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(img,
                                                 labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

            # Cutouts
            # labels = cutout(img, labels, p=0.5)
            # nl = len(labels)  # update after cutout

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.im_files[index], shapes

    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                assert im is not None, f'Image Not Found {f}'
            h0, w0 = im.shape[:2]  # orig hw
            r = self.img_size / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images_to_disk(self, i):
        # Saves an image as an *.npy file for faster loading
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
        img4, labels4 = random_perspective(img4,
                                           labels4,
                                           segments4,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img4, labels4

    def load_mosaic9(self, index):
        # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
        img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

        # Concat/clip labels
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
        img9, labels9 = random_perspective(img9,
                                           labels9,
                                           segments9,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img9, labels9

    @staticmethod
    def collate_fn(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
                                    align_corners=False)[0].type(im[i].type())
                lb = label[i]
            else:
                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im1)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()

        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4

The source file is long, around 1,000 lines in total. Let's first see which functions the class contains:

  • __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, cache_images=False, single_cls=False, stride=32, pad=0.0, min_items=0, prefix='')
  • check_cache_ram(self, safety_margin=0.1, prefix='')
  • cache_labels(self, path=Path('./labels.cache'), prefix='')
  • __len__(self)
  • __getitem__(self, index)
  • load_image(self, i)
  • cache_images_to_disk(self, i)
  • load_mosaic(self, index)
  • load_mosaic9(self, index)
  • collate_fn(batch)
  • collate_fn4(batch)

Here, __init__ handles initialization: it processes the parameters from the yaml file (dataset path, annotation path, and so on) and lets images be cached in RAM/disk to speed up training. This code needs no changes during migration, so we skip it; __len__ is a single line that you can migrate yourself.

Likewise, load_image, check_cache_ram, and cache_images_to_disk are helpers for the caching described above, so we will not dwell on them.

Let's focus on the __getitem__ function; my explanation is written into the comments:

def __getitem__(self, index):
    # Map to the real sample index
    index = self.indices[index]
    hyp = self.hyp
    # Decide whether to apply Mosaic augmentation
    mosaic = self.mosaic and random.random() < hyp['mosaic']
    if mosaic:
        # Load a mosaic image and its labels
        img, labels = self.load_mosaic(index)
        shapes = None
        if random.random() < hyp['mixup']:
            # Apply MixUp augmentation
            img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
    else:
        # Load a single image; returns the image, original size, and resized size
        img, (h0, w0), (h, w) = self.load_image(index)
        # Choose the final letterboxed shape depending on rectangular training
        shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size
        # Resize and pad; returns the scaled image, scale ratio, and padding
        img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
        # Keep the original size, scale ratio, and padding for COCO mAP rescaling
        shapes = (h0, w0), ((h / h0, w / w0), pad)
        # Copy the labels so the originals are not modified
        labels = self.labels[index].copy()
        if labels.size:
            # Convert labels from normalized xywh to pixel xyxy, adjusting for scale and padding
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Apply a random perspective transform to the image and labels
            img, labels = random_perspective(img, labels, degrees=hyp['degrees'], translate=hyp['translate'], scale=hyp['scale'], shear=hyp['shear'], perspective=hyp['perspective'])
    nl = len(labels)
    if nl:
        # Convert labels back from pixel xyxy to normalized xywh for downstream processing
        labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
    if self.augment:
        # Further augmentation via the Albumentations wrapper
        img, labels = self.albumentations(img, labels)
        # Random HSV color-space jitter
        augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
        # Flip up-down according to the hyperparameter
        if random.random() < hyp['flipud']:
            img = np.flipud(img)
            labels[:, 2] = 1 - labels[:, 2]
        # Flip left-right according to the hyperparameter
        if random.random() < hyp['fliplr']:
            img = np.fliplr(img)
            labels[:, 1] = 1 - labels[:, 1]
    # Allocate a zero tensor to hold the label output
    labels_out = torch.zeros((nl, 6))
    # Copy the processed labels into labels_out
    if nl:
        labels_out[:, 1:] = torch.from_numpy(labels)
    # Convert the image to a tensor and normalize
    img = torch.from_numpy(np.ascontiguousarray(img)) / 255.0
    # Reorder channels to C,H,W and return
    return img.permute(2, 0, 1), labels_out, index, shapes

As you can see, depending on the hyperparameters this function applies a variety of transforms to the image, including (but not limited to) mosaic and mixup. These augmentation functions need no migration and can be reused as-is.
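
For reference, mixup itself is only a few lines: it blends two images with a Beta-distributed ratio and simply concatenates their labels. A sketch following YOLOv5's utils/augmentations.py:

import numpy as np

def mixup(im, labels, im2, labels2):
    # Blend the two images; alpha = beta = 32.0 keeps the ratio close to 0.5
    r = np.random.beta(32.0, 32.0)
    im = (im * r + im2 * (1 - r)).astype(np.uint8)
    # Keep the labels of both images
    labels = np.concatenate((labels, labels2), 0)
    return im, labels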

Notice that this code uses a class attribute called albumentations. Anyone who works with images regularly will know the library: compared with PyTorch's built-in transforms, its operations run faster and cover more ground. Here, though, YOLOv5 wraps it in its own class:

class Albumentations:
    def __init__(self, size=640):
        self.transform = None
        prefix = colorstr('albumentations: ')
        try:
            import albumentations as A
            check_version(A.__version__, '1.0.3', hard=True)  # version requirement

            T = [
                A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),  # random resized crop
                A.Blur(p=0.01),  # blur
                A.MedianBlur(p=0.01),  # median blur
                A.ToGray(p=0.01),  # grayscale
                A.CLAHE(p=0.01),  # contrast-limited adaptive histogram equalization
                A.RandomBrightnessContrast(p=0.0),  # random brightness/contrast
                A.RandomGamma(p=0.0),  # random gamma correction
                A.ImageCompression(quality_lower=75, p=0.0)]  # image compression

            # Compose the transforms
            self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

            LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
        except ImportError:  # package not installed
            pass
        except Exception as e:
            LOGGER.info(f'{prefix}{e}')

    def __call__(self, im, labels, p=1.0):
        if self.transform and random.random() < p:
            # Apply the transforms to the image and labels
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
        return im, labels

Since this wrapper improves efficiency, reimplementing it is not recommended unless you have special requirements.
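
Usage is simple: instantiate the wrapper once, then call it per sample with an HWC image and labels in normalized (class, x, y, w, h) format. The arrays below are dummies for illustration:

import numpy as np

aug = Albumentations(size=640)
img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)   # dummy HWC image
labels = np.array([[0, 0.5, 0.5, 0.2, 0.3]], dtype=np.float32)   # (class, x, y, w, h), normalized
img, labels = aug(img, labels)  # silently a no-op if albumentations is not installed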

Let's also look at the DataLoader part; again, the explanation is given in comments:

batch_size = min(batch_size, len(dataset))  # make sure the batch size does not exceed the dataset size
nd = torch.cuda.device_count()  # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)  # distributed sampler; rank=-1 means no distributed sampling
loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows attribute updates
generator = torch.Generator()  # random number generator
generator.manual_seed(6148914691236517205 + seed + RANK)  # seed the generator
return loader(dataset,
              batch_size=batch_size,
              shuffle=shuffle and sampler is None,  # shuffle only when no distributed sampler is used
              num_workers=nw,  # number of worker processes
              sampler=sampler,  # sampler
              pin_memory=PIN_MEMORY,  # whether to keep data in CUDA pinned memory
              collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,  # batch collation function
              worker_init_fn=seed_worker,  # worker init function
              generator=generator), dataset  # return the dataloader and the dataset object

That completes the PyTorch side. Below, we rewrite these functions with MindSpore.

Rewriting YOLOv5's Data Loading in MindSpore

Next we implement the YOLOv5 pipeline with MindSpore. During the rewrite, the following points need attention (a minimal sketch follows the list):

  1. GeneratorDataset must be told the dataset's column names
  2. The random seed is set by calling set_seed on mindspore.dataset.config, not by passing it to the GeneratorDataset constructor
  3. The batch size is configured after instantiation, via the .batch() method
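
Put together, the three points look roughly like this (a minimal sketch; MyDataset is the class we define below):

import mindspore.dataset as de

# 2. The seed goes through the global dataset config, not the constructor
de.config.set_seed(1236517205)

# 1. Column names must be passed explicitly
dataset = MyDataset(dataset_path='./coco/train2017.txt')
loader = de.GeneratorDataset(dataset, column_names=['samples'], shuffle=True)

# 3. The batch size is set after instantiation
loader = loader.batch(32, drop_remainder=True)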

So we cannot transplant the PyTorch structure wholesale; some adjustments are required. Let's start with the dataset part. For convenience we do not keep hyperparameters in a file; instead, the ones we may need go straight into the class's __init__, as follows:

class MyDataset:
    def __init__(
        self,
        dataset_path="",         # dataset path
        img_size=640,            # image size
        transforms_dict=None,    # list of dicts describing the image transforms
        is_training=False,
        augment=False,           # whether to apply data augmentation
        rect=False,
        single_cls=False,        # whether to treat the data as a single class
        batch_size=32,           # batch size
        stride=32,               # stride
        num_cls=80,              # number of classes
        pad=0.0,                 # padding (pixels)
        return_segments=False,   # whether to return segmentation results
        return_keypoints=False,  # whether to return keypoint results
        nkpt=0,
        ndim=0
    ):
        # Supported image suffixes
        self.img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']
        self.cache_version = 0.2

        self.return_segments = return_segments
        self.return_keypoints = return_keypoints
        assert not (return_segments and return_keypoints), 'Can not return both segments and keypoints.'

        self.path = dataset_path
        self.img_size = img_size
        self.augment = augment
        self.rect = rect
        self.stride = stride
        self.num_cls = num_cls
        self.nkpt = nkpt
        self.ndim = ndim
        self.transforms_dict = transforms_dict
        self.is_training = is_training

        # 'self.column_names_getitem' names the data items returned for a single sample
        self.column_names_getitem = ['samples']

        # Training mode
        if self.is_training:
            # 'self.column_names_collate' names the data items returned after batch collation
            self.column_names_collate = ['images', 'labels']

            # Segmentation info
            if self.return_segments:
                self.column_names_collate = ['images', 'labels', 'masks']
            # Keypoint info
            elif self.return_keypoints:
                self.column_names_collate = ['images', 'labels', 'keypoints']
        # Evaluation mode
        else:
            self.column_names_collate = ["images", "img_files", "hw_ori", "hw_scale", "pad"]

        # Read the data
        try:
            f = []
            for p in self.path if isinstance(self.path, list) else [self.path]:
                p = Path(p)
                if p.is_dir():
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                elif p.is_file():
                    with open(p, "r") as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  
                else:
                    raise Exception(f"{p} does not exist")
            self.img_files = sorted([x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in self.img_formats])
            assert self.img_files, f"No images found"
        except Exception as e:
            raise Exception(f"Error loading data from {self.path}: {e}\n")

        # Check the cache
        self.label_files = self._img2label_paths(self.img_files)  # label paths
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix(".cache.npy")  # label cache
        if cache_path.is_file():
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load the dict
            if cache["version"] == self.cache_version \
                    and cache["hash"] == self._get_hash(self.label_files + self.img_files):
                logger.info("Dataset Cache file hash/version check success.")
                logger.info(f"Load dataset cache from [{cache_path}] success.")
            else:
                logger.info("Dataset cache file hash/version check fail.")
                logger.info("Dataset caching now...")
                cache, exists = self.cache_labels(cache_path), False  # re-cache
                logger.info("Dataset caching success.")
        else:
            logger.info("No dataset cache available, caching now...")
            cache, exists = self.cache_labels(cache_path), False  # cache
            logger.info("Dataset caching success.")

        # Display the cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f"No labels in {cache_path}. Can not train without labels. See {self.help_url}"

        # Read the cache
        cache.pop("hash")  # remove the hash
        cache.pop("version")  # remove the version
        self.labels = cache['labels']
        self.img_files = [lb['im_file'] for lb in self.labels]  # update im_files

        # Check whether the dataset is all bounding boxes or all segments
        lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in self.labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            print(
                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
            for lb in self.labels:
                lb['segments'] = []
        if len_cls == 0:
            raise ValueError(f'All labels empty in {cache_path}, can not start training without labels.')

        if single_cls:
            for x in self.labels:
                x['cls'][:, 0] = 0

        n = len(self.labels)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(np.int_)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of each image

        # Cache images into memory to speed up training
        self.imgs, self.img_hw_ori, self.indices = None, None, range(n)

        # Rectangular training
        if self.rect:
            # Sort by aspect ratio
            s = self.img_shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.img_shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int_) * stride  # np.int is gone in recent NumPy; use np.int_

        self.imgIds = [int(Path(im_file).stem) for im_file in self.img_files]

The comments show what the initialization does: __init__ reads the dataset, loads it into memory, sets the image sizes, and so on.
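
One helper worth a note is _img2label_paths, which derives each label path from the image path by swapping the images/ directory for labels/ and the suffix for .txt. A sketch along the lines of YOLOv5's img2label_paths (the MindSpore port may differ in detail):

import os

def _img2label_paths(self, img_paths):
    # .../images/train2017/000000000001.jpg -> .../labels/train2017/000000000001.txt
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]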

Next come the __len__ and __getitem__ functions:

def __getitem__(self, index):
    # Fetch the sample at the given index
    sample = self.get_sample(index)

    # Walk through every transform in transforms_dict
    for _i, ori_trans in enumerate(self.transforms_dict):
        # Copy the current transform's dict
        _trans = ori_trans.copy()
        # Pop the function name and its probability
        func_name, prob = _trans.pop("func_name"), _trans.pop("prob", 1.0)
        # 'copy_paste' manages its probability itself
        if func_name == 'copy_paste':
            sample = self.copy_paste(sample, prob)
        # Apply any other transform with the given probability
        elif random.random() < prob:
            # Lazily create the Albumentations wrapper on first use
            if func_name == "albumentations" and getattr(self, "albumentations", None) is None:
                self.albumentations = Albumentations(size=self.img_size, **_trans)
            # 'letterbox' needs the target shape
            if func_name == "letterbox":
                new_shape = self.img_size if not self.rect else self.batch_shapes[self.batch[index]]
                sample = self.letterbox(sample, new_shape, **_trans)
            # 'mosaic' (elif, so letterbox is not applied a second time by the fallback below)
            elif func_name == "mosaic":
                sample = self.mosaic(sample, **_trans)
            # Anything else: dispatch to the method of the same name
            else:
                sample = getattr(self, func_name)(sample, **_trans)

    # Make the image array contiguous
    sample['img'] = np.ascontiguousarray(sample['img'])
    # Return the sample
    return sample

def __len__(self):
    return len(self.img_files)

In this function the image transforms are described by a list of dicts, and the matching method is dispatched by func_name. This replaces the hyperparameter-file reading in the PyTorch source and makes the code somewhat more readable.
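
For example, a training pipeline under this scheme might be declared as below. The entries and parameter names are illustrative (mindyolo's YAML configs follow the same func_name/prob pattern):

transforms_dict = [
    {"func_name": "mosaic", "prob": 1.0},
    {"func_name": "random_perspective", "prob": 1.0, "degrees": 0.0, "translate": 0.1, "scale": 0.5},
    {"func_name": "albumentations", "prob": 1.0},
    {"func_name": "fliplr", "prob": 0.5},
    {"func_name": "letterbox", "prob": 1.0},
]

dataset = MyDataset(dataset_path='./coco/train2017.txt',
                    transforms_dict=transforms_dict, is_training=True)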

Most important of all is the migration of the dataloader:

import multiprocessing

import mindspore.dataset as de

def create_loader(
    dataset,
    batch_collate_fn,  # function that collates each batch
    column_names_getitem,  # column names of a single sample
    column_names_collate,  # column names after batch collation
    batch_size,
    epoch_size=1,
    rank=0,
    rank_size=1,
    num_parallel_workers=8,  # number of parallel workers for the dataset
    shuffle=True,
    drop_remainder=False,  # whether to drop the last incomplete batch
    python_multiprocessing=False,  # whether to use multiprocessing
):
    # Set the random seed
    de.config.set_seed(1236517205 + rank)
    # Number of CPU cores
    cores = multiprocessing.cpu_count()
    # Compute the number of parallel workers
    num_parallel_workers = min(int(cores / rank_size), num_parallel_workers)
    # Log the number of parallel workers
    logger.info(f"Dataloader num parallel workers: [{num_parallel_workers}]")
    # Create the dataset according to the number of devices
    if rank_size > 1:
        ds = de.GeneratorDataset(
            dataset,
            column_names=column_names_getitem,
            num_parallel_workers=min(8, num_parallel_workers),
            shuffle=shuffle,
            python_multiprocessing=python_multiprocessing,
            num_shards=rank_size,
            shard_id=rank,
        )
    else:
        ds = de.GeneratorDataset(
            dataset,
            column_names=column_names_getitem,
            num_parallel_workers=min(32, num_parallel_workers),
            shuffle=shuffle,
            python_multiprocessing=python_multiprocessing,
        )
    # Batch the dataset with the user-supplied collate function
    ds = ds.batch(
        batch_size, per_batch_map=batch_collate_fn,
        input_columns=column_names_getitem, output_columns=column_names_collate, drop_remainder=drop_remainder
    )
    # Repeat the dataset epoch_size times
    ds = ds.repeat(epoch_size)

    return ds

This code shows clear differences from the PyTorch version:

  1. MindSpore needs no DDP sampler; the equivalent is achieved by setting num_shards and shard_id
  2. MindSpore lets the user specify how many times the dataset repeats (the epoch_size parameter, implemented via ds.repeat); with PyTorch's DataLoader this is usually done by looping over epochs manually in the training loop

Also note the two special parameters column_names_getitem and column_names_collate, which were assigned in the dataset's __init__. They explicitly declare which columns a single sample returns and which columns a collated batch returns. PyTorch's DataLoader needs nothing like this: a PyTorch Dataset returns each sample as a tuple or dict carrying every required field (image, labels, and so on), and the DataLoader assembles batches automatically.
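
For completeness, here is a rough sketch of what a batch_collate_fn can look like under this scheme. MindSpore calls per_batch_map with one list per input column plus a BatchInfo object; the sample keys 'img' and 'labels' are assumptions for illustration, not mindyolo's exact layout:

import numpy as np

def batch_collate_fn(samples, batch_info):
    # 'samples' is the list of dicts produced by __getitem__, one per sample in the batch
    images = np.stack([s["img"] for s in samples], 0)  # assumed key: 'img'
    labels = []
    for i, s in enumerate(samples):
        lb = s["labels"]  # assumed key: 'labels', shape (num_objects, 6)
        lb[:, 0] = i  # write the in-batch image index, as PyTorch's collate_fn does
        labels.append(lb)
    return images, np.concatenate(labels, 0)

# Iterating over the resulting loader:
# loader = create_loader(dataset, batch_collate_fn, dataset.column_names_getitem,
#                        dataset.column_names_collate, batch_size=32)
# for batch in loader.create_dict_iterator(output_numpy=True):
#     images, labels = batch["images"], batch["labels"]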

Comparing the data fetched from the two dataloaders:

[Figure: batch fetched by the PyTorch version]

[Figure: batch fetched by the MindSpore version]

With that, our migration comes to a close.

Problems Encountered During the Migration

  1. I initially tried a one-to-one API translation; since DataLoader and GeneratorDataset differ in underlying implementation and framework design, a mechanical port does not work
  2. MindSpore 2.0 differs considerably from earlier versions and some APIs changed, so pay attention to versions when searching online
  3. MindSpore 2.0's dependencies require an older Pillow version, so remember to downgrade
  4. The YOLOv5 source is rather tangled and has many dependencies; read it carefully

Closing Remarks

The analysis above completes the migration of YOLOv5 from PyTorch to MindSpore at the data level. In a real end-to-end training workflow there are other pieces to migrate, such as model loading and training-flow control; this case study only offers an approach for migrating dataset loading and processing.

If you reimplement YOLOv5 in MindSpore from scratch, there is plenty of room for optimization and many alternative designs. Interested readers can consult the MindYOLO open-source library, which implements the successive YOLO generations in remarkably concise code.
