利用 Python 的包管理和动态属性获取(`__init__.py` 文件和 `getattr` 函数)特性来实现工厂方法模式

Python 提供了许多灵活的特性,例如包的 __init__.py 文件和 getattr 函数,这些特性可以帮助我们实现工厂方法模式来动态地创建不同类型的数据集实例。

1. 背景介绍

在深度学习项目中,我们通常需要处理多种类型的数据集,例如 COCO、Pascal VOC 和自定义的交通数据集。为了简化和统一数据集的加载过程,我们可以利用 Python 的包管理和动态属性获取特性来实现工厂方法模式。

  • 包的 __init__.py 文件:通过在包的 __init__.py 文件中导入模块,我们可以在初始化包时自动加载所有必要的类和函数。
  • getattr 函数getattr 函数允许我们动态地获取对象的属性或方法,这对于实现工厂方法模式非常有用,因为我们可以根据配置或输入动态地创建对象,而无需在代码中硬编码每种数据集的构建逻辑。

接下来,我们将通过具体的代码示例来展示如何使用这些特性来实现数据集的动态加载。

2. 模块和类的定义

在我们的项目中,数据集类被定义在 datasets 模块中。我们将定义一个 COCODataset 类,并在 datasets 模块的 __init__.py 文件中导入它。需要注意的是,COCODataset 只是众多数据集类中的一种,其他数据集类如 PascalVOCDatasetTrafficDataset 等也可以通过类似的方式定义和使用。

定义 COCODataset

datasets 模块中创建一个名为 coco.py 的文件,并定义 COCODataset 类。这个类继承自 torchvision.datasets.coco.CocoDetection,并添加了一些自定义逻辑。

# datasets/coco.py
import torchvision

class COCODataset(torchvision.datasets.coco.CocoDetection):
    def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):
        super(COCODataset, self).__init__(root, ann_file)
        # 自定义逻辑...
  • __init__ 方法COCODataset 类的构造函数接受 ann_file(注释文件路径)、root(图像根目录)、remove_images_without_annotations(是否移除没有注释的图像)和 transforms(图像变换)四个参数。这些参数与后面 DatasetCatalogget 方法返回的 args 对应。
  • 详细实现见附录
导入 COCODataset

datasets 模块的 __init__.py 文件中导入 COCODataset 类。这样可以确保在使用 datasets 模块时,所有数据集类都已加载。

# datasets/__init__.py
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .traffic_dataset import TrafficDataset
from .carWinBiaoZhi_dataset import CarWinBiaoZhiDataset
from .carWinBiaoZhi_dataset_V2 import CarWinBiaoZhiDatasetV2
from .carWinBiaoZhi_dataset_V2_1 import CarWinBiaoZhiDatasetV2_1
from .GsData import CgTrafficData
from .GsData_xianQuan import CgTrafficDataWithXianQuan
from .GsData_1cls import CgTrafficData1Cls
from .GsData_ForSemi import CgTrafficDataSemi
from .GsRadarData import CgTrafficRadarData

__all__ = [
    "COCODataset", "ConcatDataset", "PascalVOCDataset", "TrafficDataset",
    "CarWinBiaoZhiDataset", "CarWinBiaoZhiDatasetV2", "CarWinBiaoZhiDatasetV2_1", 
    "CgTrafficData", "CgTrafficDataWithXianQuan", "CgTrafficDataSemi", 
    "CgTrafficRadarData", "CgTrafficData1Cls"
]

3. 使用 getattr 动态获取工厂方法

在构建数据集实例时,我们通过 getattr 函数动态获取工厂方法。以下是实现这一过程的核心代码:

# build_dataset.py
from . import datasets as D

def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
            "dataset_list 应该是一个字符串列表,得到的是 {}".format(dataset_list)
        )
    
    datasets = []  # 初始化数据集列表
    
    for dataset_name in dataset_list:
        # 从 dataset_catalog 中获取数据集信息
        data = dataset_catalog.get(dataset_name)
        
        # 获取数据集的工厂方法
        factory = getattr(D, data["factory"])
        
        # 获取数据集的参数
        args = data["args"]
        
        # 设置数据集的变换
        args["transforms"] = transforms
        
        # 使用工厂方法创建数据集实例
        dataset = factory(**args)
        
        # 将创建的数据集添加到列表中
        datasets.append(dataset)
    
    # 如果是测试模式,返回数据集列表
    if not is_train:
        return datasets
    
    # 如果是训练模式,将所有数据集合并为一个数据集
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)
    
    return [dataset]

4. 数据集目录管理 (DatasetCatalog)

为了集中管理数据集的路径和相关信息,我们定义了 DatasetCatalog 类。这个类包含了所有数据集的配置信息,并提供了一个静态方法 get 来获取特定数据集的配置信息。

# paths_catalog.py
import os

class DatasetCatalog(object):
    DATA_DIR = "/home/Public_DataSets"
    DATASETS = {
        "coco_2017_train": {
            "img_dir": "coco/train2017",
            "ann_file": "coco/annotations/instances_train2017.json"
        },
        "voc_2007_train": {
            "data_dir": "voc/VOC2007",
            "split": "train"
        },
        # ... 其他数据集配置 ...
    }

    @staticmethod
    def get(name):
        if "coco" in name:
            data_dir = DatasetCatalog.DATA_DIR
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=os.path.join(data_dir, attrs["img_dir"]),
                ann_file=os.path.join(data_dir, attrs["ann_file"]),
            )
            return dict(
                factory="COCODataset",
                args=args,
            )
        elif "voc" in name:
            data_dir = DatasetCatalog.DATA_DIR
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                data_dir=os.path.join(data_dir, attrs["data_dir"]),
                split=attrs["split"],
            )
            return dict(
                factory="PascalVOCDataset",
                args=args,
            )
        # ... 其他数据集配置 ...
        raise RuntimeError("Dataset not available: {}".format(name))
说明

get 方法中,我们根据数据集名称动态生成配置字典。例如,对于 COCO 数据集:

return dict(
    factory="COCODataset",
    args=args,
)
  • factory:指定数据集类的名称,在后续步骤中用于动态获取工厂方法。
  • args:包含构建数据集实例所需的参数。

5. COCO 数据集的举例说明

假设我们有一个名为 "coco_2017_train" 的数据集,我们希望使用 DatasetCatalog 和工厂方法来加载这个数据集。以下是具体的步骤:

  1. 定义数据集配置

    # paths_catalog.py 中的 DATASETS 字典
    DATASETS = {
        "coco_2017_train": {
            "img_dir": "coco/train2017",
            "ann_file": "coco/annotations/instances_train2017.json"
        },
        # ... 其他数据集配置 ...
    }
    
  2. 获取数据集配置

    data = DatasetCatalog.get("coco_2017_train")
    
  3. 动态获取工厂方法

    factory = getattr(D, data["factory"])
    
  4. 创建数据集实例

    args = data["args"]
    args["transforms"] = some_transform_function  # 假设我们有一个变换函数
    dataset = factory(**args)
    

通过这种方式,我们可以动态地加载 COCO 数据集,而无需硬编码每种数据集的构建逻辑。这种设计模式提高了代码的灵活性和可维护性,使得数据集的管理和加载更加方便。

附录: COCODataset 类完整实现
# datasets/coco.py
import torch
import torchvision
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints

min_keypoints_per_image = 10

def has_valid_annotation(anno):
    if len(anno) == 0:
        return False
    if all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno):
        return False
    if "keypoints" not in anno[0]:
        return True
    if sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) >= min_keypoints_per_image:
        return True
    return False

class COCODataset(torchvision.datasets.coco.CocoDetection):
    def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):
        super(COCODataset, self).__init__(root, ann_file)
        self.ids = sorted(self.ids)
        if remove_images_without_annotations:
            self.ids = [img_id for img_id in self.ids if has_valid_annotation(self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)))]
        self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}
        self.json_category_id_to_contiguous_id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())}
        self.contiguous_category_id_to_json_id = {v: k for k, v in self.json_category_id_to_contiguous_id.items()}
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms

    def __getitem__(self, idx):
        img, anno = super(COCODataset, self).__getitem__(idx)
        anno = [obj for obj in anno if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = torch.tensor([self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in anno])
        target.add_field("labels", classes)
        if anno and "segmentation" in anno[0]:
            masks = SegmentationMask([obj["segmentation"] for obj in anno], img.size, mode='poly')
            target.add_field("masks", masks)
        if anno and "keypoints" in anno[0]:
            keypoints = PersonKeypoints([obj["keypoints"] for obj in anno], img.size)
            target.add_field("keypoints", keypoints)
        target = target.clip_to_image(remove_empty=True)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target, idx

    def get_img_info(self, index):
        return self.coco.imgs[self.id_to_img_map[index]]
  • __init__ 方法:初始化数据集,加载注释,过滤无效注释,并设置类别和图像映射。
  • __getitem__ 方法:获取指定索引的图像和注释,应用可选的变换,并返回图像、目标和索引。
  • 14
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值