pytorch padding_【开源计划】图像配准中常用损失函数的pytorch实现

前言

按照开源计划的预告,我们首先从基于深度学习的图像配准任务中常用的损失函数的代码实现开始。从我最开始的那一篇博客,即基于深度学习的医学图像配准综述,可以看出,目前基于无监督学习的图像非刚性配准模型成为了一个比较流行的研究方向。这是因为基于监督学习的方法过分依赖传统方法或者模拟变形的方法来提供监督信息,这样既吃力又不讨好。以我的探索经验来讲,以传统配准方法产生的变形场作为监督信息,对网络进行训练,很容易造成过拟合问题。因此,我综合文献综述的结论与探索的经验(有兴趣的话,我可以总结一下我的探索经历以及经验教训),最终选择了基于无监督学习的配准模型,则本文主要介绍这种模型框架下常用的损失函数。实际上,监督学习的损失函数也比较简单,只需要使用深度学习框架(如TensorFlow、PyTorch)提供的函数计算误差即可,本文使用PyTorch进行实现。

损失函数

基于无监督学习的图像非刚性配准模型的损失函数通常是由两部分组成,一个是参考图像与变形后的浮动图像的相似性测度,一个是网络预测变形场的空间正则化。以比较有名的VoxelMorph为例,它的GitHub仓库可以点击此链接。按照他最早的发表在CVPR上的论文,损失函数如下:

v2-8d321681c930cf6136bc04a37c986bdc_b.jpg


其中第一项就是相似性测度,后面一项就是空间正则化项,用以约束变形场的空间平滑性。下面我们分别对其进行介绍。

相似性测度

常用于测量图像的相似性测度有三个,一个是图像灰度的均方差(mean squared voxel differece),一个是交叉互相关(cross-correlation),一个是互信息(mutual information)。前两个通常用于单模态的图像,而第一个的鲁棒性相比于交叉互相关更差一些,比较容易受图像灰度分布与对比度等的影响。互信息通常用于多模态的图像,在单模态图像的鲁棒性更好,但是到目前为止还没有发现它被用于深度学习网络训练的损失函数中,我的猜想是互信息的计算是基于统计的,不方便进行梯度计算,与反向传播原则相违背。(该看法亟待进一步的考证。好久没看Voxelmorph的开源代码,现在已经有互信息的实现了,有时间可以研究一下)

因此,主要是]使用交叉互相关作为图像配准的损失函数。交叉互相关的公式为:(摘自VoxelMorph)

v2-e6bb08de28646d0cfd9528c994f6b652_b.jpg

​他们的代码实现-TensorFlow版请查看链接中的NCC。需要指出的是,在他们的实现版本当中,他们对于三维图像使用了一个9*9*9的窗口来计算相似性,因此成为local cross-correlation,即局部交叉互相关。(没想到现在voxelmorph还提供了pytorch版本的代码,真周到,见链接)

这里展示一下我自己参考开源代码,转写成pytorch实现的局部互相关,如下:

首先,是导入依赖库

import 

接着,是局部互相关

class LCC(nn.Module):
    """
    local (over window) normalized cross correlation (square)
    """
    def __init__(self, win=[9, 9], eps=1e-5):
        super(LCC, self).__init__()
        self.win = win
        self.eps = eps
        
    def forward(self, I, J):
        I2 = I.pow(2)
        J2 = J.pow(2)
        IJ = I * J
        
        filters = Variable(torch.ones(1, 1, self.win[0], self.win[1]))
        if I.is_cuda:#gpu
            filters = filters.cuda()
        padding = (self.win[0]//2, self.win[1]//2)
        
        I_sum = F.conv2d(I, filters, stride=1, padding=padding)
        J_sum = F.conv2d(J, filters, stride=1, padding=padding)
        I2_sum = F.conv2d(I2, filters, stride=1, padding=padding)
        J2_sum = F.conv2d(J2, filters, stride=1, padding=padding)
        IJ_sum = F.conv2d(IJ, filters, stride=1, padding=padding)
        
        win_size = self.win[0]*self.win[1]
 
        u_I = I_sum / win_size
        u_J = J_sum / win_size
        
        cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size
 
        cc = cross*cross / (I_var*J_var + self.eps)#np.finfo(float).eps
        lcc = -1.0 * torch.mean(cc) + 1
        return lcc

​除此之外,我还按照交叉互相关的定义,实现了一个全局的交叉互相关,即对整幅图计算,不依赖窗口。当图像尺寸较大时,全局交叉互相关的敏感度不如局部交叉互相关,实际效果不如局部的,其代码如下:

class 

空间正则化

在训练网络过程中,通过最大化图像的相似性测度,往往使网络产生不连续的变形场,通常要对预测的变形场施加一个空间平滑性的约束,即对变形场的空间梯度进行惩罚,如voxelmorph中的空间正则化,即计算变形场梯度的L2范数的平方:​

v2-f65e4c14fbd7900d702b6f499f6c7ab2_b.jpg

这里展示一下我自己参考开源代码,转写成pytorch实现的空间正则化,如下:

实际上,按照这种空间正则化的思想,还有其他几种方法,比如,最近的一篇期刊论文使用了一种称为折叠惩罚(bending penalty)的正则化方法,实际上就是计算变形场的二阶梯度,按照字面意思,它的目的是对变形场中的折叠进行惩罚,其公式如下:

v2-20df3585dcb81dfaf4b2faecd7b0efab_b.jpg


我按照公式实现的二维的折叠惩罚项如下,有兴趣的可以自己实现一下三维版的。

class 

另外,还有文章研究了将变形场的L1范数进行空间正则化,它的效果是尽可能减小形变的绝对值。

仿射变换的损失函数

最后,我还实现了一下仿射变换的损失函数,就是将仿射变换的参数与恒等变换的参数之间的差异求L1或L2范数。代码如下:

class 

​结束语

最后,以上仅供参考,欢迎各位网友批评指正与留言交流。如果对你有帮助,请点赞告诉我呀,我会更有动力写相关的文章。

想看配准介绍的视频的同学,请移步我的B站账号:

Bilibili 萌新up主:爱分享的毛毛Timmy的主页​space.bilibili.com

里面有我的项目介绍视频,后期也会更新更多内容,请关注我不迷路哦,如果对你有亿点点帮助,就多一键三连与我互动吧,谢谢~

版权声明

本文为CSDN博主「爱分享的毛毛Timmy」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

原文链接:【开源计划】图像配准中常用损失函数的pytorch实现_Timmymm的博客-CSDN博客_医学图像配准 pytorch实现

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值