torch做目标检测时报错 RuntimeError: CUDA error: device-side assert triggered 原因和解决方法

用自己的数据训练torchvision的maskrcnn时候,报错如下:

Traceback (most recent call last):
  File "main_train_detection.py", line 232, in <module>
    main(params)
  File "main_train_detection.py", line 201, in main
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
  File "/raid/huaqing/tyler/suzhou/code/utils/engine.py", line 37, in train_one_epoch
    loss_dict = model(images, targets)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/generalized_rcnn.py", line 97, in forward
    detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py", line 760, in forward
    loss_classifier, loss_box_reg = fastrcnn_loss(
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py", line 40, in fastrcnn_loss
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

发现根本原因是, 类别标签没有从0开始编号:
我的要识别的目标其实共有3类,所以设置的类别总数是3. 然后设置类别标签和类别的对应关系分别是:

cls_dict = {'holes':1, 'marker':2, 'band':3}.

在类别标签(label)编号的时候,其实是从0开始编号的, 对于总共3类别的情况,则label编号分别是0,1,2. 也就是说并没有label==3这个类别.所以采用上述cls_dict,将导致band类的编号溢出. 应更正如下:

cls_dict = {'holes':0, 'marker':1, 'band':2}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值