DETR Training Notes

1. To train DETR on your own dataset, first change the number of classes to detect. Start by editing the following part of detr.py:

def build(args):
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    #num_classes = 20 if args.dataset_file != 'coco' else 91
    num_classes = 1 if args.dataset_file != 'coco' else 1
    if args.dataset_file == "coco_panoptic":
        # for panoptic, we just add a num_classes that is large enough to hold
        # max_obj_id + 1, but the exact value doesn't really matter
        #num_classes = 250
        num_classes = 2
    device = torch.device(args.device)

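Set num_classes above to the number of classes you want to detect; the value in the if branch is that number plus one background class. Keep in mind that, as the comment in the code explains, num_classes has to be at least max_obj_id + 1, so a dataset whose single class has id 1 needs num_classes = 2.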
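
As a sanity check on this value, the following is a minimal sketch that derives max_obj_id + 1 from a COCO-style annotation file. The helper names and the sample category list are assumptions made for illustration; they are not part of DETR.

```python
import json


def num_classes_from_annotations(dataset):
    """Return max_obj_id + 1 for a COCO-style annotation dict."""
    category_ids = [cat["id"] for cat in dataset["categories"]]
    if not category_ids:
        raise ValueError("annotation file defines no categories")
    return max(category_ids) + 1


def num_classes_from_file(path):
    """Same as above, reading the annotation JSON from disk as UTF-8."""
    with open(path, "r", encoding="utf-8") as f:
        return num_classes_from_annotations(json.load(f))


# Hypothetical single-class dataset whose only category has id 1.
example = {"categories": [{"id": 1, "name": "line"}]}
print(num_classes_from_annotations(example))  # 2
```

Running num_classes_from_file on your own training annotation file gives the smallest value that is safe to write into detr.py.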

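2. Then run the script below. Download the pretrained weights first, or use a weight file from your own pretraining.

num_class = number of classes to detect + 1. The classification head of DETR has num_classes + 1 outputs (the extra one is the "no object" class), so the first dimension passed to resize_ below must equal num_classes + 1 for the num_classes value set in detr.py; otherwise loading the checkpoint fails with a size mismatch.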

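import torch

# Load the pretrained DETR-R50 checkpoint on the CPU; no GPU is needed here.
pretrained_weights = torch.load('./detr-r50-e632da11.pth', map_location='cpu')

num_class = 2  # number of classes to detect + 1
# Resize the classification head in place to num_class + 1 outputs.
pretrained_weights["model"]["class_embed.weight"].resize_(num_class + 1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class + 1)
torch.save(pretrained_weights, "detr-r50_line.pth")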

3. Set the hyperparameters, then start training.

Bugs encountered:

The first is a decoding problem:

Traceback (most recent call last):
  File "E:\GoogleDownload\detr-main\main.py", line 249, in <module>
    main(args)
  File "E:\GoogleDownload\detr-main\main.py", line 142, in main
    dataset_train = build_dataset(image_set='train', args=args)
  File "E:\GoogleDownload\detr-main\datasets\__init__.py", line 20, in build_dataset
    return build_coco(image_set, args)
  File "E:\GoogleDownload\detr-main\datasets\coco.py", line 157, in build
    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
  File "E:\GoogleDownload\detr-main\datasets\coco.py", line 19, in __init__
    super(CocoDetection, self).__init__(img_folder, ann_file)
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torchvision\datasets\coco.py", line 36, in __init__
    self.coco = COCO(annFile)
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\pycocotools\coco.py", line 82, in __init__
    dataset = json.load(f)
  File "E:\Anaconda3\envs\detr-main\lib\json\__init__.py", line 293, in load
    return loads(fp.read(),
UnicodeDecodeError: 'gbk' codec can't decode byte 0xaf in position 295131: illegal multibyte sequence

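Fix: open() without an explicit encoding falls back to the system locale encoding (gbk here), which cannot decode the annotation file. In E:\Anaconda3\envs\detr-main\Lib\site-packages\pycocotools\coco.py, add encoding='utf-8' to the open() call that reads the annotation file: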

        # load dataset
        self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if not annotation_file == None:
            print('loading annotations into memory...')
            tic = time.time()
            with open(annotation_file, 'r', encoding='utf-8') as f:
                dataset = json.load(f)
            assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
            print('Done (t={:0.2f}s)'.format(time.time()- tic))
            self.dataset = dataset
            self.createIndex()
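
If you would rather not edit the installed pycocotools package, one possible workaround is to rewrite the annotation file with all non-ASCII characters escaped, so that it decodes under any locale encoding, including gbk. This is a sketch that assumes the original file is valid UTF-8; the function name is hypothetical and the paths you pass in are up to you.

```python
import json


def rewrite_json_as_ascii(src_path, dst_path):
    """Re-save a UTF-8 JSON file using only ASCII, escaping everything else."""
    with open(src_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # ensure_ascii=True writes every non-ASCII character as an escape
    # sequence, so the output bytes are valid in gbk as well as UTF-8.
    with open(dst_path, "w", encoding="ascii") as f:
        json.dump(data, f, ensure_ascii=True)
```

The rewritten file carries the same data, so it can replace the original annotation file in the dataset folder.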

The second is a wrong file name in the image path being read:

Traceback (most recent call last):
  File "E:\GoogleDownload\detr-main\main.py", line 249, in <module>
    main(args)
  File "E:\GoogleDownload\detr-main\main.py", line 197, in main
    train_stats = train_one_epoch(
  File "E:\GoogleDownload\detr-main\engine.py", line 28, in train_one_epoch
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
  File "E:\GoogleDownload\detr-main\util\misc.py", line 223, in log_every
    for obj in iterable:
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\dataloader.py", line 633, in __next__
    data = self._next_data()
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\dataloader.py", line 1371, in _process_data
    data.reraise()
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\_utils.py", line 644, in reraise
    raise exception
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\_utils\worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "E:\GoogleDownload\detr-main\datasets\coco.py", line 24, in __getitem__
    img, target = super(CocoDetection, self).__getitem__(idx)
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torchvision\datasets\coco.py", line 48, in __getitem__
    image = self._load_image(id)
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\torchvision\datasets\coco.py", line 41, in _load_image
    return Image.open(os.path.join(self.root, path)).convert("RGB")
  File "E:\Anaconda3\envs\detr-main\lib\site-packages\PIL\Image.py", line 3131, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: 'E:\\GoogleDownload\\detr-main\\datasets\\nestdataset\\train2017\\RicardoQuanbu20181143'

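Fix: the file name in the error has no extension. At line 3131 of E:\Anaconda3\envs\detr-main\Lib\site-packages\PIL\Image.py, append the image file extension to the file name. My images are JPG files, so I added .jpg. Note that this patches the library itself, so every Image.open call in this environment is affected.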

    if filename:
        fp = builtins.open(filename+'.jpg', "rb")
        exclusive_fp = True
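
A less invasive option than patching PIL might be to fix the file_name entries in the annotation file instead, so that they point at the real image files. The sketch below assumes a COCO-style dict with an images list and that every image without an extension is a .jpg; add_missing_extension is a hypothetical helper, not a DETR or pycocotools function.

```python
import os


def add_missing_extension(dataset, ext=".jpg"):
    """Append ext to every image file_name lacking an extension; return the count."""
    fixed = 0
    for image in dataset["images"]:
        if not os.path.splitext(image["file_name"])[1]:
            image["file_name"] += ext
            fixed += 1
    return fixed


# First file name is taken from the traceback above; the second is made up.
example = {"images": [{"id": 1, "file_name": "RicardoQuanbu20181143"},
                      {"id": 2, "file_name": "already_ok.jpg"}]}
print(add_missing_extension(example))  # 1
print(example["images"][0]["file_name"])  # RicardoQuanbu20181143.jpg
```

To apply it, load the annotation JSON with encoding='utf-8', call the helper, and write the dict back with json.dump.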
