用PyTorch来做物体检测和追踪
原文:https://towardsdatascience.com/object-detection-and-tracking-in-pytorch-b3cf1a696a98
翻译:https://ai.yanxishe.com/page/TextTranslation/1333
源码:https://github.com/cfotache/pytorch_objectdetecttrack
乌班图16.04
pytorch0.4
cuda9.0
出现的坑:
iou_matrix[d,t] = iou(det,trk)
RuntimeError: Expected object of type torch.FloatTensor but found type torch.DoubleTensor for argument #3 ‘other’
意思是要求输入torch.DoubleTensor类型,现在输入的是torch.FloatTensor类型(容易会错意,往相反的方向努力),那么只需要把输入的量的类型改变一下就行.
iou_matrix[d,t] = iou(det,trk)上一行加上
det=det.double()
trk=trk.double()
或者
det = det.type(torch.DoubleTensor)
trk = trk.type(torch.DoubleTensor)
再次报错
trk=trk.double()
AttributeError: ‘numpy.ndarray’ object has no attribute ‘double’
发现trk是numpy.ndarray格式,再上一行加上转换格式的trk = torch.from_numpy(trk)
,错误解决