说明
主要是为了在多模型投票(ensemble)时拿到概率(logits)输出,而不是 argmax 后的类别图
mmseg/models/segmentors/base.py
里面 forward
if return_loss:
return self.forward_train(img, img_metas, **kwargs)
else:
return self.forward_test(img, img_metas, **kwargs)
forward_test
里面
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs)
else:
# print('Using aug_test')
return self.aug_test(imgs, img_metas, **kwargs)
最后是在 mmseg/models/segmentors/encoder_decoder.py
里面实现了 simple_test
和 aug_test
,同时 EncoderDecoder
里面本身就带有 test_cfg
,那么只需在 test_cfg 里简单地加一个参数(return_logits)来控制输出:
def simple_test(self, img, img_meta, rescale=True):
    """Simple test with a single image.

    When ``test_cfg.return_logits`` is True, the raw per-class logits
    are returned instead of the argmax class map, so that outputs from
    several models can be ensembled (e.g. voting / averaging).

    Args:
        img (Tensor): Input image batch.
        img_meta (dict): Image meta information.
        rescale (bool): Whether to rescale the result to the original
            image shape. Default: True.

    Returns:
        list[numpy.ndarray] | Tensor: One array per image — either the
        class map or, with ``return_logits``, the per-class logits.
        During ONNX export a 4D tensor is returned instead.
    """
    seg_logit = self.inference(img, img_meta, rescale)
    if self.test_cfg.get("return_logits", False):
        # Keep the full (N, C, H, W) logits for ensembling.
        seg_pred = seg_logit
    else:
        seg_pred = seg_logit.argmax(dim=1)
    if torch.onnx.is_in_onnx_export():
        # our inference backend only supports 4D output
        # NOTE(review): with return_logits=True this unsqueeze yields a
        # 5D tensor — confirm ONNX export is only used with the default
        # class-map output.
        seg_pred = seg_pred.unsqueeze(0)
        return seg_pred
    seg_pred = seg_pred.cpu().numpy()
    # unravel batch dim
    seg_pred = list(seg_pred)
    return seg_pred
效果
# Example: run inference with return_logits enabled in test_cfg.
from mmseg.apis import init_segmentor, inference_segmentor
# Paths to the model config and the trained checkpoint.
config_file = '/home/user/workplace/python/202201kesai_seg/configs/upernet.py'
checkpoint_file = '/home/user/workplace/python/mmsegmentation-0.19.0/work_dirs/2022kesai_seg/upernet/latest.pth'
# Build the model from the config and load the checkpoint weights.
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
# Returns a list with one result per image; with return_logits the result
# is a (num_classes, H, W) array rather than an (H, W) label map.
result = inference_segmentor(model, '/home/user/data/202201_kesai_seg/images/000893_GF.tif')
print(result[0].shape)
输出为 (7, 256, 256),即 num_classes × H × W 的概率图
或者,测试时加上 --out probs.pkl
也可以把概率(probs)保存到文件中