打开models/yolo.py,修改Detect类的forward函数
代码如下:
if torch.onnx.is_in_onnx_export():
for i in range(self.nl): # 分别对三个输出层处理
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
y = x[i].sigmoid()
z.append(y.view(bs, -1, self.no))
return torch.cat(z, 1)
调用export.py,导出onnx