1. 使用 DETR 训练自己的数据集时,首先要修改待检测的类别数:在 detr.py 文件的 build 函数中修改下面的内容:
def build(args):
# the `num_classes` naming here is somewhat misleading.
# it indeed corresponds to `max_obj_id + 1`, where max_obj_id
# is the maximum id for a class in your dataset. For example,
# COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
# As another example, for a dataset that has a single class with id 1,
# you should pass `num_classes` to be 2 (max_obj_id + 1).
# For more details on this, check the following discussion
# https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
#num_classes = 20 if args.dataset_file != 'coco' else 91
num_classes = 1 if args.dataset_file != 'coco' else 1
if args.dataset_file == "coco_panoptic":
# for panoptic, we just add a num_classes that is large enough to hold
# max_obj_id + 1, but the exact value doesn't really matter
#num_classes = 250
num_classes = 2
device = torch.device(args.device)
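上面的 num_classes 要检测几个类别就写几个类别,if 语句中的是要加上一个背景类别。注意:按照上面代码注释的说明,num_classes 实际对应 max_obj_id + 1(例如数据集只有一个类别且其 id 为 1 时应设为 2),并且这里的取值要与下面第 2 步脚本中的 num_class 保持一致,否则加载权重时分类头的尺寸会对不上。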
2. 然后运行下面的脚本,把预训练权重中分类头(class_embed)的维度改成自己数据集需要的大小。运行前先下载好官方权重,或者使用自己预训练的权重文件。
脚本中的 num_class = 检测类别数目 + 1。
"""Trim the classification head of a pretrained DETR checkpoint.

Loads the checkpoint, keeps only the first ``num_class + 1`` rows of
``class_embed.weight`` / ``class_embed.bias`` and writes the result to a new
checkpoint file that can be used as the starting point for fine-tuning.
"""
import torch

# Number of classes to detect + 1.
num_class = 2

# map_location='cpu' lets the checkpoint load on machines that do not have
# the device it was saved from (e.g. a CPU-only box).
pretrained_weights = torch.load('./detr-r50-e632da11.pth', map_location='cpu')
state_dict = pretrained_weights["model"]

# The classification head keeps num_class + 1 output rows.
head_size = num_class + 1
for key in ("class_embed.weight", "class_embed.bias"):
    tensor = state_dict[key]
    if head_size > tensor.shape[0]:
        # Growing the head would require inventing new weights; the old
        # resize_() call silently produced uninitialized memory here.
        raise ValueError(
            f"{key} has only {tensor.shape[0]} rows; "
            f"cannot keep {head_size} rows"
        )
    # Slicing keeps the first rows (same values resize_() kept) without
    # hard-coding the hidden size; clone() drops the reference to the full
    # original storage so the saved file only contains the trimmed head.
    state_dict[key] = tensor[:head_size].clone()

torch.save(pretrained_weights, "detr-r50_line.pth")
3.设置超参数,然后开始训练。
遇见的bug:
第一个是解码问题:
Traceback (most recent call last):
File "E:\GoogleDownload\detr-main\main.py", line 249, in <module>
main(args)
File "E:\GoogleDownload\detr-main\main.py", line 142, in main
dataset_train = build_dataset(image_set='train', args=args)
File "E:\GoogleDownload\detr-main\datasets\__init__.py", line 20, in build_dataset
return build_coco(image_set, args)
File "E:\GoogleDownload\detr-main\datasets\coco.py", line 157, in build
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
File "E:\GoogleDownload\detr-main\datasets\coco.py", line 19, in __init__
super(CocoDetection, self).__init__(img_folder, ann_file)
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torchvision\datasets\coco.py", line 36, in __init__
self.coco = COCO(annFile)
File "E:\Anaconda3\envs\detr-main\lib\site-packages\pycocotools\coco.py", line 82, in __init__
dataset = json.load(f)
File "E:\Anaconda3\envs\detr-main\lib\json\__init__.py", line 293, in load
return loads(fp.read(),
UnicodeDecodeError: 'gbk' codec can't decode byte 0xaf in position 295131: illegal multibyte sequence
解决方法:Windows 下 open() 默认使用 gbk 编码读取文件,而标注文件是 UTF-8 编码,所以会解码失败。在E:\Anaconda3\envs\detr-main\Lib\site-packages\pycocotools\coco.py文件里读取标注文件的 open() 调用中加上 encoding='utf-8' 即可,修改后的代码如下:
# load dataset
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
if not annotation_file == None:
print('loading annotations into memory...')
tic = time.time()
with open(annotation_file, 'r',encoding='UTF-8') as f:
dataset = json.load(f)
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
print('Done (t={:0.2f}s)'.format(time.time()- tic))
self.dataset = dataset
self.createIndex()
第二个是读取图片时文件路径(文件名)错误:
Traceback (most recent call last):
File "E:\GoogleDownload\detr-main\main.py", line 249, in <module>
main(args)
File "E:\GoogleDownload\detr-main\main.py", line 197, in main
train_stats = train_one_epoch(
File "E:\GoogleDownload\detr-main\engine.py", line 28, in train_one_epoch
for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
File "E:\GoogleDownload\detr-main\util\misc.py", line 223, in log_every
for obj in iterable:
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\dataloader.py", line 633, in __next__
data = self._next_data()
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\dataloader.py", line 1345, in _next_data
return self._process_data(data)
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\dataloader.py", line 1371, in _process_data
data.reraise()
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\_utils.py", line 644, in reraise
raise exception
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\_utils\worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "E:\GoogleDownload\detr-main\datasets\coco.py", line 24, in __getitem__
img, target = super(CocoDetection, self).__getitem__(idx)
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torchvision\datasets\coco.py", line 48, in __getitem__
image = self._load_image(id)
File "E:\Anaconda3\envs\detr-main\lib\site-packages\torchvision\datasets\coco.py", line 41, in _load_image
return Image.open(os.path.join(self.root, path)).convert("RGB")
File "E:\Anaconda3\envs\detr-main\lib\site-packages\PIL\Image.py", line 3131, in open
fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: 'E:\\GoogleDownload\\detr-main\\datasets\\nestdataset\\train2017\\RicardoQuanbu20181143'
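解决方法:从报错信息可以看到,要打开的图片路径缺少后缀名。在E:\Anaconda3\envs\detr-main\Lib\site-packages\PIL\Image.py里面的3131行加上图片文件的后缀名,我的图片文件是jpg类型,所以加上.jpg。注意:这样修改的是 PIL 库本身,该环境中所有通过 Image.open 打开的文件名都会被追加 .jpg,只能作为临时办法;更稳妥的做法是修正标注 JSON 文件中各图片的 file_name 字段,使其带上后缀名。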
if filename:
fp = builtins.open(filename+'.jpg', "rb")
exclusive_fp = True