场景:
做目标检测,从网上找了人家的模型,调用时报错
can not import nms from torchvison.ops 。
查看源码:
路径 python_path \Lib\site-packages\torchvision\ops\boxes.py
原始写法如下:
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).
NMS iteratively removes lower scoring boxes which have an
IoU greater than iou_threshold with another (higher scoring)
box.
If multiple boxes have the exact same score and satisfy the IoU
criterion with respect to a reference box, the selected box is
not guaranteed to be the same between CPU and GPU. This is similar
to the behavior of argsort in PyTorch when repeated values are present.
Args:
boxes (Tensor[N, 4])): boxes to perform NMS on. They
are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
``0 <= y1 < y2``.
scores (Tensor[N]): scores for each one of the boxes
iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold
Returns:
Tensor: int64 tensor with the indices of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(nms)
_assert_has_ops()
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
解决方法:
重写nms
把上面代码注释掉,把下面代码粘贴上。ok
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float):
"""
:param boxes: [N, 4], 此处传进来的框,是经过筛选(NMS之前选取过得分TopK)之后, 在传入之前处理好的;
:param scores: [N]
:param iou_threshold: 0.7
:return:
"""
keep = [] # 最终保留的结果, 在boxes中对应的索引;
idxs = scores.argsort() # 值从小到大的 索引
while idxs.numel() > 0: # 循环直到null; numel(): 数组元素个数
# 得分最大框对应的索引, 以及对应的坐标
max_score_index = idxs[-1]
max_score_box = boxes[max_score_index][None, :] # [1, 4]
keep.append(max_score_index)
if idxs.size(0) == 1: # 就剩余一个框了;
break
idxs = idxs[:-1] # 将得分最大框 从索引中删除; 剩余索引对应的框 和 得分最大框 计算IoU;
other_boxes = boxes[idxs] # [?, 4]
ious = box_iou(max_score_box, other_boxes) # 一个框和其余框比较 1XM
idxs = idxs[ious[0] <= iou_threshold]
keep = idxs.new(keep) # Tensor
return keep