Recursive Cascaded Networks 添加测试代码(pytorch)

Recursive Cascaded Networks for Unsupervised Medical Image Registration 论文

论文也很容易理解,看了代码就完全明白了,很清晰易懂。

pytorch版本的改动适用于 2D3D 图像,但是中间的图像显示仅适用于 2D 图像,自行去掉就好。源码无法直接运行,跟着报错稍作改动。本人主要添加了加载模型的问题,如果想从某个epoch训练或验证的时候如何加载网络。

class RecursiveCascadeNetwork(nn.Module):
    """For Training"""
    def __init__(self, n_cascades, im_size, device, state_dict=None, testing=False):
        super(RecursiveCascadeNetwork, self).__init__()

        self.stems = []
        # See note in base_networks.py about the assumption in the image shape
        init_model = VTNAffineStem(dim=len(im_size), im_size=im_size[0]).to(device)
        self.stems.append(init_model)
        for i in range(n_cascades):
            model = VTN(dim=len(im_size), flow_multiplier=1.0 / n_cascades).to(device)
            self.stems.append(model)

        self.reconstruction = SpatialTransform(im_size).to(device)

        # 如果有模型,则加载已有模型
        if state_dict:
            for i, m in enumerate(self.stems):
                m.load_state_dict(state_dict[f'cascade {i}'])

        if testing:
            for m in self.stems:
                m.eval()
            self.reconstruction.eval()

    def forward(self, fixed, moving):
        flows = []
        stem_results = []
        # Affine registration
        flow = self.stems[0](fixed, moving)
        stem_results.append(self.reconstruction(moving, flow))
        flows.append(flow)
        for model in self.stems[1:]: # cascades
            # registration between the fixed and the warped from last cascade
            flow = model(fixed, stem_results[-1])
            stem_results.append(self.reconstruction(stem_results[-1], flow))
            flows.append(flow)

        return stem_results, flows

checkpoint = 'ckp/model_wts/epoch_120.pth'
state_dict = torch.load(checkpoint)

如果从中间开始训练,引入 state_dict 即可,如果测试,将 testing 改为 True

具体的结果还在改一些细节,总是有点小问题,目前还没有确定,待确定就有结果了,迫不及待。

[1] Zhao S , Y Dong, Chang E , et al. Recursive Cascaded Networks for Unsupervised Medical Image Registration[C]// 2019 IEEE/CVF International Conference on Computer Vision (ICCV). IEEE, 2020.
[2] Zhao S , Lau T , J Luo, et al. Unsupervised 3D End-to-End Medical Image Registration With Volume Tweening Network[J]. IEEE Journal of Biomedical and Health Informatics, 2020, 24(5):1394-1404.

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值