Pytorch实现CT图像正投影(FP)与反投影(FBP)的模块

FP/FBP Modules

有关CT图像重建或图像处理的训练任务有时需要数据在投影域和图像域上进行变换,为了能使梯度在投影域和图像域之间进行传播,需要实现Forward Projection与Back Projection模块。
参考文献中提到的平行束正反投方法,可以进行网络中正反投模块的Pytorch实现。

[1] Zhao, J. , et al. “Unsupervised Learnable Sinogram Inpainting Network (SIN) for Limited Angle CT reconstruction.” (2018).

正反投模块的原理示意图下图1所示:radon变换即正投的步骤可以看成是旋转+累加,radon反变换即反投的过程可以看成是滤波+旋转(反投)+累加的过程。radon反变换的滤波可以在时域进行也可以在频域进行,因为时域的卷积等于频域的乘积,本博文是在频率上进行得Ramp滤波。

图1 Radon变换与反变换示意图

Pytorch实现

正投模块(FP)的实现:

class FP(nn.Module):
    def __init__(self, viewNum, chanNum, batchSize):
        super(FP, self).__init__()
        self.viewNum = viewNum
        self.chanNum = chanNum
        self.batchSize = batchSize
    def forward(self, x):
        '''
            x: image 
            x is a tensor (batchSize*netChanNum*imgSize*imgSize)
        '''
        sino = torch.from_numpy( np.zeros((self.batchSize, 1, self.chanNum, self.viewNum))).type(torch.FloatTensor) # batchSize*channel*512*360
        sino = sino.cuda()
        ''' rotate'''
        for i in range(self.viewNum):
            angle = - 180/self.viewNum*(i+1) * math.pi / 180 - math.pi
            A = np.array([[np.cos(angle), -np.sin(angle)],
                          [np.sin(angle), np.cos(angle)]])  
            theta = np.array([[A[0, 0], A[0, 1], 0], [A[1, 0], A[1, 1], 0]])                                    
            theta = torch.from_numpy(theta).type(torch.FloatTensor)
            theta = theta.unsqueeze(0)
            theta = theta.repeat(self.batchSize,1,1)
            theta = theta.cuda()
            ''' interpolation'''
            grid = F.affine_grid(theta, x.size())
            x_rotate = F.grid_sample(x, grid) # 4*1*512*512
            ''' accumulation'''
            sino[:,:,:,i] = torch.sum(x_rotate, dim=2) 
        sino = sino*0.5
        sino = sino.cuda()    
        return sino

反投模块(FBP)的实现:

class FBP(nn.Module):
    def __init__(self, viewNum, chanNum, batchSize, netChanNum, chanSpacing):
        super(FBP, self).__init__()
        self.viewNum = viewNum # projection的投影角度数
        self.chanNum = chanNum # projection的通道数
        self.batchSize = batchSize
        self.netChanNum = netChanNum # 输入FBP网络数据的通道数
        self.chanSpacing = chanSpacing
    
    def forward(self, x):
        '''
            x:  projection (batchSize*netChanNum*chanNum*viewNum) 4*1*512*360
            type(x) is a tensor
        '''
        '''频域滤波'''
        projectionValue = convolution(x,self.batchSize,self.netChanNum,self.chanNum,self.viewNum,self.chanSpacing) # 2*1*512*360
        projectionValue = projectionValue.cuda()
        sino_rotate = np.zeros((self.batchSize, self.netChanNum, self.viewNum, self.chanNum, self.chanNum)) # batchSize*netChanNum*viewNum*chanNum*chanNum
        sino_rotate = torch.from_numpy(sino_rotate).type(torch.FloatTensor)
        sino_rotate = sino_rotate.cuda()
        AglPerView = math.pi/self.viewNum
        '''设置FOV,生成mask将FOV以外的区域置零'''
        FOV = torch.ones((self.batchSize,self.netChanNum,self.chanNum,self.chanNum))
        x_linespace = np.arange(1,self.chanNum+1,1)  # (512,)
        y_linespace = np.arange(1,self.chanNum+1,1)  # (512,)
        x_mesh,y_mesh = np.meshgrid(x_linespace,y_linespace) # 512*512
        XPos = (x_mesh-256.5) * self.chanSpacing # 512*512
        YPos = (y_mesh-256.5) * self.chanSpacing # 512*512
        R = np.sqrt(XPos**2 + YPos**2) # 512*512
        R = torch.from_numpy(R).type(torch.FloatTensor) # 512*512
        R = R.repeat(self.batchSize,self.netChanNum,1,1) # 2*1*512*512
        FOV[R>=self.chanSpacing*self.chanNum/2] = 0 # 2*1*512*512
        FOV = FOV.cuda()
        ''' rotate interpolation'''
        for i in range(self.viewNum):
            projectionValueFiltered = torch.unsqueeze(projectionValue[:,:,:,i],3) # 2*1*512*1
            projectionValueRepeat = projectionValueFiltered.repeat(1,1,1,512) # 2*1*512*512
            projectionValueRepeat = projectionValueRepeat * FOV  # 2*1*512*512
            angle = -math.pi/2 + 180/self.viewNum*(i+1) * math.pi / 180
            A = np.array([[np.cos(angle), -np.sin(angle)],
                          [np.sin(angle), np.cos(angle)]])
            theta = np.array([[A[0, 0], A[0, 1], 0], [A[1, 0], A[1, 1], 0]])
            theta = torch.from_numpy(theta).type(torch.FloatTensor)
            theta = theta.unsqueeze(0)
            theta = theta.cuda()
            theta = theta.repeat(self.batchSize,1,1)
            grid = F.affine_grid(theta, torch.Size((self.batchSize, self.netChanNum, 512, 512)))
            sino_rotate[:,:,i,:,:] = F.grid_sample(projectionValueRepeat, grid) # 512*512
        ''' accumulation'''
        iradon = torch.sum(sino_rotate, dim=2)  
        iradon = iradon*AglPerView
        return iradon

频域滤波的实现:其中调用到的Ramp()函数参照博文

https://blog.csdn.net/kouwang9779/article/details/115961582

def convolution(proj,batchSize,netChann,channNum,viewnum,channSpacing):
    AglPerView = np.pi/viewnum
    channels = 512
    origin = np.zeros((batchSize,netChann,viewnum, channels, channels))
    # avoid truncation
    step = list(np.arange(0,1,1/100))
    step2 = step.copy()
    step2.reverse()
    step = np.array(step) # (100,)
    step = np.expand_dims(step,axis=1) # 100*1
    step = torch.from_numpy(step).type(torch.FloatTensor) # (100,1)
    step = step.repeat(batchSize,1,1,360) # 2*1*100*360
    step_temp = proj[:,:,0,:].unsqueeze(2) # 2*1*1*360
    step_temp = step_temp.repeat(1,1,100,1) # 2*1*100*360
    step = step.cuda()
    step = step*step_temp # 2*1*100*360
    step2 = np.array(step2) # (100,)
    step2 = np.expand_dims(step2,axis=1) # 100*1
    step2 = torch.from_numpy(step2).type(torch.FloatTensor) # (100,1)
    step2 = step2.repeat(batchSize,1,1,360) # 2*1*100*360
    step2_temp = proj[:,:,-1,:].unsqueeze(2) # 2*1*1*360
    step2_temp = step2_temp.repeat(1,1,100,1) # 2*1*100*360
    step2 = step2.cuda()
    step2 = step2*step2_temp # 2*1*100*360
    filterData = Ramp(batchSize,netChann,2*100+channNum,channSpacing) # 2*1*2048*360
    iLen = filterData.shape[2] # 2048
    proj = torch.cat((step,proj,step2),2) # 2*1*712*360
    proj = torch.cat((proj,torch.zeros(batchSize,netChann,iLen-proj.shape[2],viewnum).cuda()),2) # 2*1*2048*360
    sino_fft = fft(proj.detach().cpu().numpy(),axis=2) # 2*1*2048*360
    image_filter = filterData*sino_fft # 2*1*2048*360
    image_filter_ = ifft(image_filter,axis=2) # 2*1*2048*360
    image_filter_ = np.real(image_filter_)
    image_filter_ = torch.from_numpy(image_filter_).type(torch.FloatTensor)
    image_filter_final = image_filter_[:,:,100:512+100] # 2*1*512*360
    return image_filter_final

结果展示


图2 原图像(左),通过FP模块后的正投结果(中),正投结果通过FBP模块后的反投结果(右)

讨论

以这种方法进行正反投之后的误差即原图像与反投结果之差如图3所示,这个误差范围应该可以通过网络的训练进行弥补。

图3 原图像与反投影图像的误差图

  • 3
    点赞
  • 57
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值