参考代码
# 编译 nms
cd pytorch-retinanet/lib
bash build.sh
出现如下问题,是因为,参考的代码中 pytorch==0.4, 我使用 pytorch=1.0
ImportError: torch.utils.ffi is deprecated. Please use cpp extensions instead
解决方法:
- replace nms lib: https://github.com/huaifeng1993/NMS
- this code has been compiled.if you need compile:
cd nms
rm -rf /build
rm *.so
cd ..
python setup3.py build_ext --inplace
#at last,you need modify code in model.py:
#from lib.nms.pth_nms import pth_nms
from lib.nms.gpu_nms import gpu_nms
and
#return pth_nms(dets, thresh)
return gpu_nms(dets, thresh)
raise error:
TypeError: Argument 'dets' has incorrect type (expected numpy.ndarray, got Tensor)
我的解决方式是:
# anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5)
anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :].cpu().numpy(), 0.5)
# 再 visualize.py 中添加
scores = scores.cpu().numpy()
上述issue中解决方式是(这种方式可能需要重新编译):
# you need change the dets to numpy:
add in gpu_nms() :
dets = dets.numpy()
cv2 退出窗口
k = cv2.waitKey(0) # waitkey代表读取键盘的输入,括号里的数字代表等待多长时间,单位ms。 0代表一直等待
if k == 27: # 键盘上Esc键的键值
cv2.destroyAllWindows()
break # 终止循环