看到了导出的onnx自动执行了sofmax操作,不想要这个操作
修改如下代码:改为
在encoder_decoder_mapy.py
def simple_test(self, img, img_meta, rescale=True): """Simple test with single image.""" # MOD # seg_logit = self.inference(img, img_meta, rescale) seg_logit, maps = self.inference(img, img_meta, rescale) if torch.onnx.is_in_onnx_export(): # our inference backend only support 4D output seg_pred = seg_logit.unsqueeze(0) return seg_pred seg_pred = seg_logit.argmax(dim=1) seg_pred = seg_pred.cpu().numpy() # unravel batch dim seg_pred = list(seg_pred) # MOD # return seg_pred return seg_pred, maps
torch.onnx.is_in_onnx_export()是判断是否是onnx输出流,将argmax操作放到后面就可以了