【代码整理】基于COCO格式的pytorch Dataset类实现

【更新】
1.Transform添加了测试时增强.(推理一张图像)
2.COCODataset添加了map变量.(COCO categories字段id未按照顺序,将其按顺序映射)
3.COCODataset添加了filterImgIds变量和filterImgById()方法.(COCO数据集中有些图像没有标注,过滤掉这些图像)
4.filterImgById()同时过滤训练集中’000000200365.jpg’, '000000550395.jpg’这两张图像,因为存在bbox的w或h=0

import模块

import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt

基于albumentations库自定义数据预处理/数据增强

class Transform():
    '''数据预处理/数据增强(基于albumentations库)
    '''
    def __init__(self, imgSize):
        maxSize = max(imgSize[0], imgSize[1])
        # 训练时增强
        self.trainTF = A.Compose([
                A.BBoxSafeRandomCrop(p=0.5),
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                # 随机翻转
                A.HorizontalFlip(p=0.5),
                # 参数:随机色调、饱和度、值变化
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),
                # 随机明亮对比度
                A.RandomBrightnessContrast(p=0.2),   
                # 高斯噪声
                A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     
                A.OneOf([
                    # 使用随机大小的内核将运动模糊应用于输入图像
                    A.MotionBlur(p=0.2),   
                    # 中值滤波
                    A.MedianBlur(blur_limit=3, p=0.1),    
                    # 使用随机大小的内核模糊输入图像
                    A.Blur(blur_limit=3, p=0.1),  
                ], p=0.2),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.0, label_fields=['category_ids']),
            )
        # 验证时增强
        self.validTF = A.Compose([
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.0, label_fields=['category_ids']),
            )
        # 测试时增强
        self.testTF = A.Compose([
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ])

自定义数据集读取类COCODataset实现



class COCODataset(Dataset):

    def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True, map=None):
        '''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径

        Args:
            :param annPath:     COCO annotation 文件路径
            :param imgDir:      图像的根目录
            :param inputShape: 网络要求输入的图像尺寸
            :param trainMode:   训练集/测试集
            :param trainMode:   训练集/测试集

        Returns:
            FRCNNDataset
        '''      
        self.mode = trainMode
        self.tf = Transform(imgSize=inputShape)
        self.imgDir = imgDir
        self.annPath = annPath
        # 为实例注释初始化COCO的API
        self.coco=COCO(annPath)
        # 获取数据集中所有图像对应的imgId
        self.imgIds = self.coco.getImgIds()
        # 如果标签的id正好不是按顺序来的,还需进行映射
        self.map = map
        '''过滤掉那些没有框的图像,很重要!!!'''
        self.filterImgIds = self.filterImgById()
                

    def __len__(self):
        '''重载data.Dataset父类方法, 返回数据集大小
        '''
        return len(self.filterImgIds)
    
    def __getitem__(self, index):
        '''重载data.Dataset父类方法, 获取数据集中数据内容
           这里通过pycocotools来读取图像和标签
        '''   
        # 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}
        imgId = self.filterImgIds[index]
        imgInfo = self.coco.loadImgs(imgId)[0]
        # 载入图像 (通过imgInfo获取图像名,得到图像路径)               
        image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))
        image = np.array(image.convert('RGB'))
        # 得到图像里包含的BBox的所有id
        imgAnnIds = self.coco.getAnnIds(imgIds=imgId)   
        # 通过BBox的id找到对应的BBox信息
        anns = self.coco.loadAnns(imgAnnIds) 
        # 获取BBox的坐标和类别
        labels, boxes = [], []
        for ann in anns:
            labelName = ann['category_id']
            labels.append(labelName)
            boxes.append(ann['bbox'])
        labels = np.array(labels)
        boxes = np.array(boxes)
        
        # 训练/验证时的数据增强各不相同
        if(self.mode):
            # albumentation的图像维度得是[W,H,C]
            transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)
        else:
            transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)
        # 这里的box是coco格式(xywh)
        image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']
        if self.map != None:
            label = [self.map[i] for i in label]
        # 再把coco格式转成VOC格式(x0, y0, x1, y1):
        box = [[b[0], b[1], b[2]+b[0], b[3]+b[1]] for b in box]
        return image.transpose(2,0,1), np.array(box), np.array(label)
    


    def filterImgById(self):
        '''过滤掉那些没标注的图像
        '''
        print('filtering no objects images...')
        filterImgIds = []
        for i in tqdm(range(len(self.imgIds))):
            # 获取图像信息(json文件 "images" 字段)
            imgInfo = self.coco.loadImgs(self.imgIds[i])[0]
            # 得到当前图像里包含的BBox的所有id
            annIds = self.coco.getAnnIds(imgIds=imgInfo['id'])
            # anns (json文件 "annotations" 字段)
            anns = self.coco.loadAnns(annIds)
            if len(anns)!=0:
                # 专门针对COCO数据集,这两张图片存在bbox的w或h=0的情况:
                if imgInfo['file_name'] not in ['000000200365.jpg', '000000550395.jpg']:
                    filterImgIds.append(self.imgIds[i])
        return filterImgIds

其他


# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):
    images = []
    bboxes = []
    labels = []
    for img, box, label in batch:
        images.append(img)
        bboxes.append(box)
        labels.append(label)
    images = torch.from_numpy(np.array(images))
    return images, bboxes, labels



# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):
    worker_seed = worker_id + seed
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


# 固定全局随机数种子
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

batch数据集可视化


def visBatch(dataLoader:DataLoader):
    '''可视化训练集一个batch
    Args:
        dataLoader: torch的data.DataLoader
    Retuens:
        None     
    '''
    # COCO
    catName = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
               'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
               'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
               'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
    # VOC0712
    for step, batch in enumerate(dataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # 只可视化一个batch的图像:
        if step > 0: break
        # 图像均值
        mean = np.array([0.485, 0.456, 0.406]) 
        # 标准差
        std = np.array([[0.229, 0.224, 0.225]]) 
        plt.figure(figsize = (8,8))
        for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):
            img, box, label = imgBoxLabel
            ax = plt.subplot(4,4,idx+1)
            img = img.numpy().transpose((1,2,0))
            # 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原
            img = img * std + mean
            for instBox, instLabel in zip(box, label):
                x0, y0, x1, y1 = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])
                # 显示框
                ax.add_patch(plt.Rectangle((x0, y0), x1-x0, y1-y0, color='blue', fill=False, linewidth=2))
                # 显示类别
                ax.text(x0, y0, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})
            plt.imshow(img)
            # 在图像上方展示对应的标签
            # 取消坐标轴
            plt.axis("off")
             # 微调行间距
            plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)
        plt.show()

example

    # 固定随机种子
    seed = 22
    seed_everything(seed)
    # BatcchSize
    BS = 16
    # 图像尺寸
    imgSize = [800, 800]
    '''COCO'''
    trainAnnPath = "E:/datasets/Universal/COCO2017/COCO/annotations/instances_train2017.json"
    testAnnPath = "E:/datasets/Universal/COCO2017/COCO/annotations/instances_val2017.json"
    imgDir =  "E:/datasets/Universal/COCO2017/COCO/train2017"
    cls_num = 80
    map = {1:0, 2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7, 9:8, 10:9, 11:10, 13:11, 14:12, 15:13, 16:14, 17:15, 18:16, 19:17, 20:18, 21:19, 22:20, 23:21, 
           24:22, 25:23, 27:24, 28:25, 31:26, 32:27, 33:28, 34:29, 35:30, 36:31, 37:32, 38:33, 39:34, 40:35, 41:36, 42:37, 43:38, 44:39, 46:40, 
           47:41, 48:42, 49:43, 50:44, 51:45, 52:46, 53:47, 54:48, 55:49, 56:50, 57:51, 58:52, 59:53, 60:54, 61:55, 62:56, 63:57, 64:58, 65:59, 
           67:60, 70:61, 72:62, 73:63, 74:64, 75:65, 76:66, 77:67, 78:68, 79:69, 80:70, 81:71, 82:72, 84:73, 85:74, 86:75, 87:76, 88:77, 89:78, 90:79}
    # 自定义数据集读取类
    trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True, map=map)
    trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size=BS, num_workers=2, pin_memory=True,
                                    collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))
    # validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)
    # validDataLoader = DataLoader(validDataset, shuffle=True, batch_size=BS, num_workers = 1, pin_memory=True, 
                                  # collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))



    print(f'训练集大小 : {trainDataset.__len__()}')
    # visBatch(trainDataLoader)
    cnt = 0
    for step, batch in enumerate(trainDataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # torch.Size([bs, 3, 800, 800])
        print(f'images.shape : {images.shape}')   
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一
        print(f'len(boxes) : {len(boxes)}')     
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一  
        print(f'len(labels) : {len(labels)}')     
        break

输出

在这里插入图片描述

loading annotations into memory...
Done (t=11.97s)
creating index...
index created!
filtering no objects images...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118287/118287 [00:00<00:00, 164053.90it/s] 
训练集大小 : 117264
images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16
  • 8
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch 提供了许多内置的数据集,可以用于训练和评估模型。这些数据集可以通过 torchvision 库进行访问和加载。以下是一些常见的 PyTorch 数据集示例: 1. MNIST: 手写数字图像数据集。可以使用 `torchvision.datasets.MNIST` 加载。 ```python from torchvision import datasets train_dataset = datasets.MNIST(root='./data', train=True, download=True) test_dataset = datasets.MNIST(root='./data', train=False, download=True) ``` 2. CIFAR-10: 包含10个别的彩色图像数据集。可以使用 `torchvision.datasets.CIFAR10` 加载。 ```python from torchvision import datasets train_dataset = datasets.CIFAR10(root='./data', train=True, download=True) test_dataset = datasets.CIFAR10(root='./data', train=False, download=True) ``` 3. ImageNet: 大型图像数据集,包含超过一百万张带有标签的图像。可以使用 `torchvision.datasets.ImageNet` 加载,但这需要额外的设置和准备工作。 ```python from torchvision import datasets # 示例代码仅展示了加载方式,实际使用需要额外配置 train_dataset = datasets.ImageNet(root='./data', split='train') test_dataset = datasets.ImageNet(root='./data', split='val') ``` 除了这些内置数据集,PyTorch 还提供了其他许多数据集,例如 COCO、VOC 等。此外,你也可以创建自己的自定义数据集,继承 `torch.utils.data.Dataset` 实现 `__len__` 和 `__getitem__` 方法。 希望这些信息对你有帮助!如果有任何其他问题,请随时提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值