报错:
TypeError: forward() missing 5 required positional arguments: 'prior_data', 'num_classes', 'top_k', 'conf_thresh', and 'nms_thresh'
解决方案:
解决方案:
1.参数不对,test.py正常运行,eval.py报错,
把test.py代码复制到eval.py
把这段代码
#eval.py
detections = detect(
loc.view(loc.size(0), -1, 4),
softmax(conf.view(conf.size(0), -1, config.class_num)),
torch.cat([o.view(-1, 4) for o in priors],0),
).data
替换为
#eval.py
detections = detect(
loc.view(loc.size(0), -1, 4),
softmax(conf.view(conf.size(0), -1, config.class_num)),
torch.cat([o.view(-1, 4) for o in priors], 0),
config.class_num,
200,
0.7,
0.45,
).data
2.pytorch版本更新
把
#eva.py
detect = Detect(config.class_num, 0, 200, 0.01, 0.45)
替换为
#eva.py
detect = Detect.apply # pytorch新版本需要这样使用
更改之后,就可以正常运行