PyTorch multi-process data loading - serialized data / serialize_data

Background

When building datasets in OpenMMLab projects, the base class BaseDataset includes an optimization for multi-process data loading called "serialization" (serialize_data).

Let's look at the code first.



class BaseDataset(Dataset):
    r"""BaseDataset for open source projects in OpenMMLab.

    The annotation format is shown as follows.

    .. code-block:: none

        {
            "metainfo":
            {
              "dataset_type": "test_dataset",
              "task_name": "test_task"
            },
            "data_list":
            [
              {
                "img_path": "test_img.jpg",
                "height": 604,
                "width": 640,
                "instances":
                [
                  {
                    "bbox": [0, 0, 10, 20],
                    "bbox_label": 1,
                    "mask": [[0,0],[0,10],[10,20],[20,0]],
                    "extra_anns": [1,2,3]
                  },
                  {
                    "bbox": [10, 10, 110, 120],
                    "bbox_label": 2,
                    "mask": [[10,10],[10,110],[110,120],[120,10]],
                    "extra_anns": [4,5,6]
                  }
                ]
              },
            ]
        }

    Args:
        ann_file (str, optional): Annotation file path. Defaults to ''.
        metainfo (Mapping or Config, optional): Meta information for
            dataset, such as class information. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(img_path='').
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None which means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects. When enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. In some cases, such as visualization, only the meta
            information of the dataset is needed, so it is not necessary to
            load the annotation file. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``. Defaults
            to False.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
            None img, the maximum extra number of cycles to get a valid
            image. Defaults to 1000.

    Note:
        BaseDataset collects meta information from ``annotation file`` (the
        lowest priority), ``BaseDataset.METAINFO``(medium) and ``metainfo
        parameter`` (highest) passed to constructors. The lower priority meta
        information will be overwritten by higher one.

    Note:
        Dataset wrappers such as ``ConcatDataset``, ``RepeatDataset``, etc.
        should not inherit from ``BaseDataset``, since ``get_subset`` and
        ``get_subset_`` could produce sub-datasets with ambiguous meaning that
        conflict with the original dataset.

    Examples:
        >>> # Assume the annotation file is given above.
        >>> class CustomDataset(BaseDataset):
        >>>     METAINFO: dict = dict(task_name='custom_task',
        >>>                           dataset_type='custom_type')
        >>> metainfo=dict(task_name='custom_task_name')
        >>> custom_dataset = CustomDataset(
        >>>                      'path/to/ann_file',
        >>>                      metainfo=metainfo)
        >>> # meta information of annotation file will be overwritten by
        >>> # `CustomDataset.METAINFO`. The merged meta information will
        >>> # further be overwritten by argument `metainfo`.
        >>> custom_dataset.metainfo
        {'task_name': custom_task_name, dataset_type: custom_type}
    """

    METAINFO: dict = dict()
    _fully_initialized: bool = False

    def __init__(self,
                 ann_file: Optional[str] = '',
                 metainfo: Union[Mapping, Config, None] = None,
                 data_root: Optional[str] = '',
                 data_prefix: dict = dict(img_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000):
        self.ann_file = ann_file
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Join paths.
        self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Full initialize the dataset.
        if not lazy_init:
            self.full_init()

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index and automatically call ``full_init`` if the
        dataset has not been fully initialized.

        The serialized path extracts samples from a shared byte buffer via
        address calculation, a memoryview, and deserialization, which suits
        large-scale data and keeps memory usage low; the non-serialized path
        is simpler and more direct, and works well when the dataset is small
        or memory is plentiful.
        Args:
            idx (int): The index of data.

        Returns:
            dict: The idx-th annotation of the dataset.
            Either way, the resulting ``data_info`` holds the data for index
            ``idx``:
            - with serialized loading, the sample is extracted from the byte
              array via address calculation, a memoryview, and deserialization;
            - with non-serialized loading, the sample is copied directly from
              the stored list of objects.
            The two paths suit different storage scenarios and performance
            needs.
        """
        if self.serialize_data:
            start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
            end_addr = self.data_address[idx].item()
            bytes = memoryview(
                self.data_bytes[start_addr:end_addr])  # type: ignore
            data_info = pickle.loads(bytes)  # type: ignore
        else:
            data_info = copy.deepcopy(self.data_list[idx])
        # Some codebase needs `sample_idx` of data information. Here we convert
        # the idx to a positive number and save it in data information.
        if idx >= 0:
            data_info['sample_idx'] = idx
        else:
            data_info['sample_idx'] = len(self) + idx

        return data_info

    def full_init(self):
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True.

        If ``lazy_init=False``, ``full_init`` will be called during the
        instantiation and ``self._fully_initialized`` will be set to True. If
        ``obj._fully_initialized=False``, the class method decorated by
        ``force_full_init`` will call ``full_init`` automatically.

        Several steps to initialize annotation:

            - load_data_list: Load annotations from annotation file.
            - filter data information: Filter annotations according to
              filter_cfg.
            - slice_data: Slice dataset according to ``self._indices``
            - serialize_data: Serialize ``self.data_list`` if
              ``self.serialize_data`` is True.
        """
        
        # Return early if the dataset has already been fully initialized.
        if self._fully_initialized:
            return
        # load data information
        self.data_list = self.load_data_list()
        # filter illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()
        # Get subset data according to indices.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)

        # serialize data_list
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()

        self._fully_initialized = True

The BaseDataset class defines several methods and attributes that directly affect memory usage, for example:

  • data_list: a list holding every sample in the dataset; each sample is a dict containing the image path, image size, instance annotations, and so on.
  • serialize_data: a boolean indicating whether data_list is serialized at initialization to save memory. When enabled, data loader worker processes can use shared RAM from the master process instead of making their own copies.
  • _serialize_data and _get_serialized_subset: methods that serialize the data and extract subsets of the serialized data, which helps reduce memory consumption during multi-process data loading (a sketch of the serialization step follows this list).
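
A minimal sketch of what the serialization step does, modeled on the idea behind _serialize_data (written here as a free function for illustration; names and details are simplified, not the verbatim mmengine implementation): each sample dict is pickled into a uint8 buffer, the cumulative buffer lengths form the address array, and all buffers are concatenated into one byte array that worker processes can share.

import pickle
from typing import List, Tuple

import numpy as np


def serialize_data_list(data_list: List[dict]) -> Tuple[np.ndarray, np.ndarray]:
    """Pickle every sample dict into one flat uint8 array.

    Sample i occupies data_bytes[data_address[i - 1]:data_address[i]],
    with an implicit start address of 0 for i == 0.
    """
    buffers = [
        np.frombuffer(pickle.dumps(sample, protocol=4), dtype=np.uint8)
        for sample in data_list
    ]
    # Cumulative lengths give the end address of each sample's bytes.
    data_address = np.cumsum([len(b) for b in buffers], dtype=np.int64)
    data_bytes = np.concatenate(buffers)
    return data_bytes, data_address


# Round-trip two tiny "annotations" through the byte array.
data_bytes, data_address = serialize_data_list(
    [dict(img_path='a.jpg', height=604), dict(img_path='b.jpg', height=480)])
start, end = data_address[0].item(), data_address[1].item()
print(pickle.loads(memoryview(data_bytes[start:end])))  # {'img_path': 'b.jpg', 'height': 480}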

In distributed training, if every GPU rank loads the complete data_list, memory is duplicated across ranks. When serialize_data is set to True, the data is serialized to save memory, so that each worker process can share the master process's RAM instead of keeping its own copy.

serialize_data

In multi-process data loading, for example with PyTorch's DataLoader, each worker process typically loads part of the dataset so that samples can be processed in parallel. Without serialization, every worker copies a complete data_list into its own memory space, which wastes a large amount of memory, especially when the dataset is big.

With serialization enabled via the serialize_data parameter, the sample information of the dataset is converted into a binary array (data_bytes), and the start and end position of each sample is recorded in an address array (data_address). When a data loader worker needs a sample, it extracts the corresponding bytes directly from the shared data_bytes array by address instead of copying the entire data list. All worker processes can therefore use the shared memory of the master process directly, which greatly reduces memory usage.
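
To make the mechanism concrete, here is a small self-contained toy example (the class name ToySerializedDataset and the sample layout are made up for illustration; this is not the mmengine class) that stores its samples exactly this way and is then consumed by a multi-worker DataLoader. On Linux, where DataLoader workers are forked from the master process, the single data_bytes array is shared copy-on-write rather than duplicated per worker.

import pickle

import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToySerializedDataset(Dataset):
    """Toy dataset that keeps samples as one pickled byte array plus an
    address array, mirroring the serialize_data idea."""

    def __init__(self, samples):
        buffers = [np.frombuffer(pickle.dumps(s), dtype=np.uint8) for s in samples]
        self.data_address = np.cumsum([len(b) for b in buffers])
        self.data_bytes = np.concatenate(buffers)

    def __len__(self):
        return len(self.data_address)

    def __getitem__(self, idx):
        # Locate the idx-th sample's bytes in the shared array and deserialize.
        start = 0 if idx == 0 else self.data_address[idx - 1].item()
        end = self.data_address[idx].item()
        sample = pickle.loads(memoryview(self.data_bytes[start:end]))
        return sample['value']


if __name__ == '__main__':
    dataset = ToySerializedDataset([dict(value=i) for i in range(8)])
    # Each worker indexes into data_bytes instead of copying a list of dicts.
    loader = DataLoader(dataset, batch_size=4, num_workers=2)
    for batch in loader:
        print(batch)  # tensor([0, 1, 2, 3]), then tensor([4, 5, 6, 7])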

Understanding serialize_data further

Let's use a restaurant analogy to understand what serialize_data does.

Imagine you run a very popular restaurant with 100 dishes on the menu. Every day you have to serve these dishes, but demand differs from dish to dish. To serve customers efficiently, you have two options:

  1. No serialization (serialize_data=False)
    This is like giving every waiter a complete menu with all 100 dishes. Every morning, each waiter collects from the kitchen all the ingredients they might need for the day. Every waiter therefore has to carry a large amount of ingredients, and the kitchen has to stock enough to satisfy all of them. This works when the restaurant is small and customers are few, but in a large restaurant with many customers it puts enormous pressure on the kitchen's inventory and becomes very inefficient.

  2. Serialization (serialize_data=True)
    Now you change strategy. Instead of preparing a full set of ingredients for every waiter, the kitchen packs the ingredients for each dish into separate parcels and labels each parcel with the dish it belongs to. Waiters simply pick up the parcel that matches each customer's order. The kitchen only needs to stock enough ingredients for the total demand of all customers rather than for every waiter individually, and waiters no longer have to carry piles of ingredients. This greatly reduces waste and storage pressure in the kitchen and improves service efficiency.

In the context of dataset handling, serialize_data plays the role of these ingredient parcels. Without serialization, every worker process (waiter) needs a complete copy of the dataset (the full menu), causing heavy memory usage and data duplication. With serialization enabled, every sample is packed into a binary "parcel" (data_bytes) with an address label (data_address); workers load and process only the parcels they need rather than the whole dataset, which significantly reduces memory usage and improves data-processing efficiency.
