PyTorch Code for Affine Registration

This is a PyTorch port of the affine registration network (the affine stem) described in the Recursive Cascaded Networks paper; the original implementation is in TensorFlow.

import torch 
import torch.nn as nn 
import torch.nn.functional as F

import numpy as np 

def Conv(in_chn, out_chn, kernel_size, stride, padding):
    return nn.Conv3d(in_chn, out_chn, kernel_size, stride, padding)

def ConvReLU(in_chn, out_chn, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv3d(in_chn, out_chn, kernel_size, stride, padding),
        nn.ReLU()
    )

def ConvLeakyReLU(in_chn, out_chn, kernel_size, stride, padding, alpha=0.1):
    return nn.Sequential(
        nn.Conv3d(in_chn, out_chn, kernel_size, stride, padding),
        nn.LeakyReLU(alpha)
    )

def UpConv(in_chn, out_chn, kernel_size, stride, padding):
    return nn.ConvTranspose3d(in_chn, out_chn, kernel_size, stride, padding)

def UpConvReLU(in_chn, out_chn, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose3d(in_chn, out_chn, kernel_size, stride, padding),
        nn.ReLU()
    )

def UpConvLeakyReLU(in_chn, out_chn, kernel_size, stride, padding, alpha=0.1):
    return nn.Sequential(
        nn.ConvTranspose3d(in_chn, out_chn, kernel_size, stride, padding),
        nn.LeakyReLU(alpha)
    )

def affine_flow(W, b, len1, len2, len3, device):    
    """
    W: [1,3,3]  tensor-order-independent
    b: [1,3]    tensor-order-independent
    len1: the length of D, or dim1 of the volume
    len2: the length of H, or dim2... 
    len3: the length of W, or dim3...

    the function itself will generate the tensor-order of NDHWC when running, so TRANSPOSE IS NEEDED AFTER affine_flow
    """
    # N C D H W 
    b = torch.reshape(b, [-1, 1, 1, 1, 3])
    xr = torch.arange(-(len1 - 1) / 2.0, len1 / 2.0, 1.0).to(device)
    xr = torch.reshape(xr, [1, -1, 1, 1, 1])
    yr = torch.arange(-(len2 - 1) / 2.0, len2 / 2.0, 1.0).to(device)
    yr = torch.reshape(yr, [1, 1, -1, 1, 1])
    zr = torch.arange(-(len3 - 1) / 2.0, len3 / 2.0, 1.0).to(device)
    zr = torch.reshape(zr, [1, 1, 1, -1, 1])
    wx = W[:, :, 0]
    wx = torch.reshape(wx, [-1, 1, 1, 1, 3])
    wy = W[:, :, 1]
    wy = torch.reshape(wy, [-1, 1, 1, 1, 3])
    wz = W[:, :, 2]
    wz = torch.reshape(wz, [-1, 1, 1, 1, 3])
    return (xr * wx + yr * wy) + (zr * wz + b)
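
# Quick sanity check (a hedged sketch, not part of the original network): with
# W = 0 and b = 0 the affine map A = W + I is the identity, so the displacement
# field should be all zeros, with channel-last shape (N, D, H, W, 3).
def _check_affine_flow_identity():
    W0 = torch.zeros(1, 3, 3)
    b0 = torch.zeros(1, 3)
    flow = affine_flow(W0, b0, 4, 5, 6, 'cpu')
    assert flow.shape == (1, 4, 5, 6, 3)
    assert torch.allclose(flow, torch.zeros_like(flow))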

def det3x3(M):
    """
    M: [1,3,3] tensor-order-independent
    """

    M = [[M[:, i, j] for j in range(3)] for i in range(3)]
    return (
        M[0][0] * M[1][1] * M[2][2] + 
        M[0][1] * M[1][2] * M[2][0] + 
        M[0][2] * M[1][0] * M[2][1]) - (
            M[0][0] * M[1][2] * M[2][1] + 
            M[0][1] * M[1][0] * M[2][2] + 
            M[0][2] * M[1][1] * M[2][0])
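
# Sanity check (hedged sketch): det3x3 should agree with torch.linalg.det on a
# random batch; torch.linalg.det requires PyTorch >= 1.8.
def _check_det3x3():
    M = torch.randn(4, 3, 3)
    assert torch.allclose(det3x3(M), torch.linalg.det(M), atol=1e-5)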


def elem_sym_polys_of_eigen_values(M):
    """
    M: [1,3,3] tensor-order-independent
    """
    
    M = [[M[:, i, j] for j in range(3)] for i in range(3)]
    sigma1 = M[0][0] + M[1][1] + M[2][2]
    sigma2 = (
        M[0][0] * M[1][1] + 
        M[1][1] * M[2][2] + 
        M[2][2] * M[0][0]) - \
        (M[0][1] * M[1][0] +
            M[1][2] * M[2][1] + 
            M[2][0] * M[0][2])
    sigma3 = (
        M[0][0] * M[1][1] * M[2][2] + 
        M[0][1] * M[1][2] * M[2][0] + 
        M[0][2] * M[1][0] * M[2][1]) - \
        (M[0][0] * M[1][2] * M[2][1] + 
        M[0][1] * M[1][0] * M[2][2] + 
        M[0][2] * M[1][1] * M[2][0])
    return sigma1, sigma2, sigma3
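
# Sanity check (hedged sketch): for a symmetric matrix these are the elementary
# symmetric polynomials of its eigenvalues, so sigma1 should equal the
# eigenvalue sum (trace) and sigma3 the eigenvalue product (determinant).
def _check_elem_sym_polys():
    X = torch.randn(2, 3, 3)
    C = torch.bmm(X.transpose(1, 2), X)  # symmetric positive semi-definite
    s1, s2, s3 = elem_sym_polys_of_eigen_values(C)
    lam = torch.linalg.eigvalsh(C)       # real eigenvalues of each symmetric C
    assert torch.allclose(s1, lam.sum(dim=1), atol=1e-4)
    assert torch.allclose(s3, lam.prod(dim=1), atol=1e-4)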


class VTNAffineStem(nn.Module):
    def __init__(self, flow_multiplier=1):
        super(VTNAffineStem, self).__init__()
        self.flow_multiplier = flow_multiplier
        # zero-size parameter used only to query the module's device in forward()
        self.dummy_param = nn.Parameter(torch.empty(0))
        
        self.conv1 = ConvLeakyReLU(2, 16, 3, 2, 1)
        self.conv2 = ConvLeakyReLU(16, 32, 3, 2, 1)
        self.conv3 = ConvLeakyReLU(32, 64, 3, 2, 1)
        self.conv3_1 = ConvLeakyReLU(64, 64, 3, 1, 1)
        self.conv4 = ConvLeakyReLU(64, 128, 3, 2, 1)
        self.conv4_1 = ConvLeakyReLU(128, 128, 3, 1, 1)
        self.conv5 = ConvLeakyReLU(128, 256, 3, 2, 1)
        self.conv5_1 = ConvLeakyReLU(256, 256, 3, 1, 1)
        self.conv6 = ConvLeakyReLU(256, 512, 3, 2, 1)
        self.conv6_1 = ConvLeakyReLU(512, 512, 3, 1, 1)
        # after six stride-2 convs a 64x192x192 input is reduced to 1x3x3, so
        # these [1, 3, 3] kernels act as a fully connected readout of W (9
        # channels) and b (3 channels)
        self.conv7_W = nn.Conv3d(512, 9, [1, 3, 3], 1)
        self.conv7_b = nn.Conv3d(512, 3, [1, 3, 3], 1)

    def forward(self, img1, img2):
        device = self.dummy_param.device
        x = torch.cat([img1, img2], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv3_1(x)
        x = self.conv4(x)
        x = self.conv4_1(x)
        x = self.conv5(x)
        x = self.conv5_1(x)
        x = self.conv6(x)
        x = self.conv6_1(x)
        x_conv7_W = self.conv7_W(x)
        x_conv7_b = self.conv7_b(x)

        I = torch.eye(3, device=device).unsqueeze(0)  # [1, 3, 3] identity

        W = torch.reshape(x_conv7_W, [-1, 3, 3]) * self.flow_multiplier
        b = torch.reshape(x_conv7_b, [-1, 3]) * self.flow_multiplier

        A = W + I 
        
        sx, sy, sz = img1.shape[2:]
        # inputs are NCDHW; affine_flow builds the flow channel-last (NDHWC)
        flow = affine_flow(W, b, sx, sy, sz, device)
        flow = flow.permute(0, 4, 1, 2, 3)  # NDHWC (TF style) -> NCDHW (PyTorch style)
        det = det3x3(A)
        det_loss = (det - 1.0)**2 / 2.0  # penalizes folding/scaling: det(A) should stay near 1
        det_loss = torch.sum(det_loss)

        eps = 1e-5
        epsI = eps * I  # eps * identity, same [1, 3, 3] shape as A
        C = torch.bmm(A.transpose(1, -1), A) + epsI
        s1, s2, s3 = elem_sym_polys_of_eigen_values(C)
        # VTN orthogonality loss: sum_i (lambda_i + (1+eps)^2 / lambda_i) - 6(1+eps),
        # with lambda_i the eigenvalues of A^T A; it is minimized when A is orthogonal
        ortho_loss = s1 + (1 + eps) * (1 + eps) * s2 / s3 - 3 * 2 * (1 + eps)
        ortho_loss = torch.sum(ortho_loss)

        return {
                'flow': flow,
                'W': W,
                'b': b,
                'det_loss': det_loss,
                'ortho_loss': ortho_loss
            }


if __name__ == '__main__':
    dummy_input1 = torch.randn(1, 1, 64, 192, 192).cuda()
    dummy_input2 = torch.randn(1, 1, 64, 192, 192).cuda()
    model = VTNAffineStem().cuda()
    out = model(dummy_input1, dummy_input2)
    print(out.keys())
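
The network only predicts the flow; warping the moving image happens elsewhere in the original pipeline. Below is a minimal sketch of how the predicted flow could be applied with F.grid_sample. It assumes the flow channels are voxel-unit displacements along D/H/W (the layout produced above) and handles grid_sample's normalized, (x, y, z)-ordered grid convention as shown; these normalization details are my assumption and should be checked against the original TensorFlow warping code. Note that torch.meshgrid with indexing='ij' requires PyTorch >= 1.10.

def warp(moving, flow):
    # moving: [N, C, D, H, W]; flow: [N, 3, D, H, W], voxel-unit displacements
    # with channels 0/1/2 along D/H/W (matching affine_flow after the permute)
    N, _, D, H, W = moving.shape
    device = moving.device
    zz, yy, xx = torch.meshgrid(
        torch.arange(D, dtype=torch.float32, device=device),
        torch.arange(H, dtype=torch.float32, device=device),
        torch.arange(W, dtype=torch.float32, device=device),
        indexing='ij')
    identity = torch.stack([zz, yy, xx], dim=0).unsqueeze(0)  # [1, 3, D, H, W]
    coords = identity + flow                                  # target voxel coords
    # grid_sample expects normalized [-1, 1] coordinates ordered (x, y, z),
    # i.e. (W, H, D), in a channel-last [N, D, H, W, 3] grid
    d = 2.0 * coords[:, 0] / max(D - 1, 1) - 1.0
    h = 2.0 * coords[:, 1] / max(H - 1, 1) - 1.0
    w = 2.0 * coords[:, 2] / max(W - 1, 1) - 1.0
    grid = torch.stack([w, h, d], dim=-1)
    return F.grid_sample(moving, grid, align_corners=True)

# Example usage (assuming the demo above ran on a CUDA device):
# warped = warp(dummy_input2, out['flow'])   # -> [1, 1, 64, 192, 192]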

    

After half a day of debugging, it turns out you still can't beat a senior labmate.
