本文用于记录利用Pytorch官方目标检测模型(torchvision.models.detection.xx)训练自定义数据集的多GPU并行训练爬坑过程,不具备普适性排雷功能,代码规范度较差,主要用于个人记录回溯。本文基于:Pytorch官方fastrcnn的Tutorial以及中文博客完成。
目录
-
环境
win10-homebasic(Linux);Pytorch 1.4.0;torchvision 0.5.0; Python 3.6.5/3.8.5;OpenCV 4.2.0 ;CUDA 10.0
尝试过Pytorch 1.7.0+torchvision 0.7.0 + CUDA 10.2,会报nms()函数没定义的错误。
-
数据集
自定义数据集由labelImg生成的Pascal格式标注集,格式如下图所示。Pytorch中torchvision.model可能是仅接受COCO格式,因此本文手写了格式转换的代码。官方Tutorial中有一点需要注意,maskrcnn输入的target必须包含<boxes>,<labels>和<mask>。而官方说明中<mask>是optional。本文中选择fasterrcnn,只需要包含前两个。
-
数据处理代码
数据处理代码基于torchvision.datasets.VOCDetection(path)改写。有一些需要注意的地方。首先,__ini__()函数中我加了一段判断:
if os.path.splitext(i)[1] != ".jpg":
idx = file_names.index(i)
file_names.__delitem__(idx)
这是由于我在Jupyter中查看‘JPEGImages’文件夹中有300个.jpg文件,而读取时多了一个xx_checkpoint文件。因此多加了一步检查。
重点在__getitem__()函数中的target_form_trans()。它的功能是将parse_voc_xml得到的xml信息转换成coco格式(?我不确定这么写对不对,更严格的说是符合fasterrcnn输入的格式)。parse_voc_xml得到了一个以annotation为唯一key的dict。而annotation这个key的键值是我们需要的标注信息。因此target_form_trans第一行先把annotation的键值作为新的字典进行处理。接下来最关键的标注信息存放在dic的key=‘object’的键值中。这里面最大的一个坑是如果一幅图像有多个标注框,就会在dic这个字典中产生多个key=‘object’的内容。而字典是不允许key重名的,因此要考虑如何找到所有的object。如果object有多个,就需要通过以下语句:
if isinstance(dic['object'],list):
num_objs = len(dic['object'])
elif isinstance(dic['object'],dict):
num_objs = 1
else:
num_objs = 0
来判断object的个数。因为多个object时,dic['object']的格式为list,而单个object时格式为dic。因此仅仅用len(dic['object'])会在object数量为1时返回其字典的key数量,造成计数错误。其余不表,至于<area>和<iscrowd>是复旦人群数据库用到的,这里用不到,删不删随意。
import os
import torch
import sys
import tarfile
import collections
import numpy as np
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.voc import VOCDetection
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET
from PIL import Image
from torchvision.datasets.utils import download_url, check_integrity, verify_str_arg
# 用于在图像中显示的class label
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'cover', 'uncover', 'other' ,'empty'
]
class VOCDataset(VisionDataset):
def __init__(self,
root,
transforms=None,
transform=None,
target_transform=None):
super(VOCDataset, self).__init__(root, transforms, transform, target_transform)
voc_root = self.root
self.image_dir = os.path.join(voc_root, 'JPEGImages') #数据集中图片放在'JPEGImages'文件夹
self.annotation_dir = os.path.join(voc_root, 'Annotations') #数据集中标注xml放在'Annotations'文件夹
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.')
#用于存放train和test标签的ImageSets文件夹,存放的是txt文件,暂时不用
#splits_dir = os.path.join(voc_root, 'ImageSets/Main')
#split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
#with open(os.path.join(split_f), "r") as f:
# file_names = [x.strip() for x in f.readlines()]
#加载图片和标注
file_names = os.listdir(self.image_dir,)
for i in file_names: # 循环读取路径下的文件并筛选输出
if os.path.splitext(i)[1] != ".jpg":
idx = file_names.index(i)
file_names.__delitem__(idx)
file_names = [x.split(".",1)[0] for x in file_names] #分离出.jpg
self.images = [os.path.join(self.image_dir, x + ".jpg") for x in file_names]
self.annotations = [os.path.join(self.annotation_dir, x + ".xml") for x in file_names]
assert (len(self.images) == len(self.annotations))
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a dictionary of the XML tree.
"""
img = Image.open(self.images[index]).convert('RGB')
dic = self.parse_voc_xml(
ET.parse(self.annotations[index]).getroot())
target = self.target_form_trans(dic,index)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def target_form_trans(self,dic,index):
dic = dic['annotation']
# 处理image_id字段
image_id = torch.tensor([index])
target = {}
target["boxes"] = []
target["labels"] = []
target["image_id"] = image_id
target["area"] = []
target["iscrowd"] = []
# 处理object中包含的字段
boxes = []
labels =