Pytorch torchvision构建Faster-rcnn(一)----coco数据读取

torchvision在更新到0.3版本后,增添了很多新的功能,其中就包括整个目标检测算法/分割算法模块。这里打算将Faster-rcnn代码从torchvision分离出来,并分组件记录下Pytorch torchvision官方实现的Faster-rcnn代码并做记录和注释。

注:各个模块的代码均可以在0.3.0+版本的torchvision中找到

torchvision github地址:https://github.com/pytorch/vision

目录

数据读取

准备工作

 读取coco数据集

创建自己的data transform

定义ConvertCocoPolysToMask

定义Compose

定义RandomHorizontalFlip

定义ToTensor

使用自己定义的transforms


数据读取

准备工作

首先要确保torchvision升级到0.3.0版本以上,另外加载coco数据集需要用到pycocotools api,需要提前安装。没有安装的可以利用下述命令安装:

# 安装Cython
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple Cython

# 安装pycocotools
pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

 读取coco数据集

torchvision中实现了coco数据集的读取api CocoDetection,其定义如下:

torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None, transforms=None)

其中参数定义:

  • root : coco图片路径
  • annFile : 标注文件路径
  • transform : 图像转换(用于PIL)
  • target_transform : 标注转换
  • transforms : 图像和标注的转换

一个简单的pytorch coco dataloader创建及可视化程序如下:

import torch
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets
from torchvision.transforms import functional as F
import cv2
import random

font = cv2.FONT_HERSHEY_SIMPLEX

root = '/public/yzy/coco/2014/train2014/'
annFile = '/public/yzy/coco/2014/annotations/instances_train2014.json'

# 定义 coco collate_fn
def collate_fn_coco(batch):
    return tuple(zip(*batch))

# 创建 coco dataset
coco_det = datasets.CocoDetection(root,annFile,transform=T.ToTensor())
# 创建 Coco sampler
sampler = torch.utils.data.RandomSampler(coco_det)
batch_sampler = torch.utils.data.BatchSampler(sampler, 8, drop_last=True)

# 创建 dataloader
data_loader = torch.utils.data.DataLoader(
        coco_det, batch_sampler=batch_sampler, num_workers=3,
        collate_fn=collate_fn_coco)

# 可视化
for imgs,labels in data_loader:
    for i in range(len(imgs)):
        bboxes = []
        ids = []
        img = imgs[i]
        labels_ = labels[i]
        for label in labels_:
            bboxes.append([label['bbox'][0],
            label['bbox'][1],
            label['bbox'][0] + label['bbox'][2],
            label['bbox'][1] + label['bbox'][3]
            ])
            ids.append(label['category_id'])

        img = img.permute(1,2,0).numpy()
        img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
        for box ,id_ in zip(bboxes,ids):
            x1 = int(box[0])
            y1 = int(box[1])
            x2 = int(box[2])
            y2 = int(box[3])
            cv2.rectangle(img,(x1,y1),(x2,y2),(0,0,255),thickness=2)
            cv2.putText(img, text=str(id_), org=(x1 + 5, y1 + 5), fontFace=font, fontScale=1, 
                thickness=2, lineType=cv2.LINE_AA, color=(0, 255, 0))
        cv2.imshow('test',img)
        cv2.waitKey()

效果:

注意由于coco中的图片大小不一致,因此需要使用BatchSampler并重新定义collate_fn。

创建自己的data transform

注意CocoDetection返回的结果有img和target两项,而torchvision.transform中只实现了对img的transform,因此对于target的transform我们需要自己实现。

定义ConvertCocoPolysToMask

ConvertCocoPolysToMask将每个image对应的target转化成为一个dict,这个dict中保存了该图片的所有标注信息,其中对于目标检测的有用信息是boxes和labels,其如下:

target :{

'boxes' : 包含所有box的Tensor,大小为N×4,其中N为图像中包含目标的总个数

'labels' : 包含所有目标的类别信息的Tensor,长度为N

...

}

class ConvertCocoPolysToMask(object):
    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
        target["area"] = area
        target["iscrowd"] = iscrowd

        return image, target

定义Compose

compose用于执行对image和target的transform,对应于torchvision.transforms中的Compose()类

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

定义RandomHorizontalFlip

RandomHorizontalFlip实现对image和target的水平翻转

class RandomHorizontalFlip(object):
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
        return image, target

定义ToTensor

ToTensor将image转化为Pytorch Tensor类型:

class ToTensor(object):
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target

使用自己定义的transforms

我们将上述函数定义在transforms.py中,就可以使用自己定义的transfrom了。在使用的时候只需要将coco_dataset的定义改为如下即可:

import transforms as T
coco_det = datasets.CocoDetection(root,annFile,
    transform=T.Compose([
        T.ConvertCocoPolysToMask(),
        T.ToTensor(),
        T.RandomHorizontalFlip(0.5)
        
]))

至此,我们就完成了通过torchvision读取coco数据集的步骤。

  • 19
    点赞
  • 80
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值