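# PyTorch port of the affine registration network described in the Recursive Cascaded Networks paper; the original source code is written in TensorFlow.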
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def Conv(in_chn, out_chn, kernel_size, stride, padding):
    """Build a bare 3-D convolution (no activation).

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels.
        kernel_size: kernel size passed to ``nn.Conv3d``.
        stride: stride passed to ``nn.Conv3d``.
        padding: zero-padding passed to ``nn.Conv3d``.

    Returns:
        An ``nn.Conv3d`` module.
    """
    layer = nn.Conv3d(
        in_chn,
        out_chn,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    return layer
def ConvReLU(in_chn, out_chn, kernel_size, stride, padding):
    """Build a 3-D convolution followed by ReLU.

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels.
        kernel_size: kernel size passed to ``nn.Conv3d``.
        stride: stride passed to ``nn.Conv3d``.
        padding: zero-padding passed to ``nn.Conv3d``.

    Returns:
        ``nn.Sequential(conv, ReLU)``; the convolution is submodule ``0``.
    """
    conv = nn.Conv3d(in_chn, out_chn, kernel_size=kernel_size, stride=stride, padding=padding)
    activation = nn.ReLU()
    return nn.Sequential(conv, activation)
def ConvLeakyReLU(in_chn, out_chn, kernel_size, stride, padding, alpha=0.1):
    """Build a 3-D convolution followed by LeakyReLU.

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels.
        kernel_size: kernel size passed to ``nn.Conv3d``.
        stride: stride passed to ``nn.Conv3d``.
        padding: zero-padding passed to ``nn.Conv3d``.
        alpha: negative slope of the LeakyReLU (default 0.1).

    Returns:
        ``nn.Sequential(conv, LeakyReLU)``; the convolution is submodule ``0``.
    """
    conv = nn.Conv3d(in_chn, out_chn, kernel_size=kernel_size, stride=stride, padding=padding)
    activation = nn.LeakyReLU(negative_slope=alpha)
    return nn.Sequential(conv, activation)
def UpConv(in_chn, out_chn, kernel_size, stride, padding):
    """Build a bare 3-D transposed convolution (no activation).

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels.
        kernel_size: kernel size passed to ``nn.ConvTranspose3d``.
        stride: stride passed to ``nn.ConvTranspose3d``.
        padding: padding passed to ``nn.ConvTranspose3d``.

    Returns:
        An ``nn.ConvTranspose3d`` module.
    """
    layer = nn.ConvTranspose3d(
        in_chn,
        out_chn,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    return layer
def UpConvReLU(in_chn, out_chn, kernel_size, stride, padding):
    """Build a 3-D transposed convolution followed by ReLU.

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels.
        kernel_size: kernel size passed to ``nn.ConvTranspose3d``.
        stride: stride passed to ``nn.ConvTranspose3d``.
        padding: padding passed to ``nn.ConvTranspose3d``.

    Returns:
        ``nn.Sequential(up_conv, ReLU)``; the transposed convolution is submodule ``0``.
    """
    up_conv = nn.ConvTranspose3d(in_chn, out_chn, kernel_size=kernel_size, stride=stride, padding=padding)
    activation = nn.ReLU()
    return nn.Sequential(up_conv, activation)
def UpConvLeakyReLU(in_chn, out_chn, kernel_size, stride, padding, alpha=0.1):
    """Build a 3-D transposed convolution followed by LeakyReLU.

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels.
        kernel_size: kernel size passed to ``nn.ConvTranspose3d``.
        stride: stride passed to ``nn.ConvTranspose3d``.
        padding: padding passed to ``nn.ConvTranspose3d``.
        alpha: negative slope of the LeakyReLU (default 0.1).

    Returns:
        ``nn.Sequential(up_conv, LeakyReLU)``; the transposed convolution is submodule ``0``.
    """
    up_conv = nn.ConvTranspose3d(in_chn, out_chn, kernel_size=kernel_size, stride=stride, padding=padding)
    activation = nn.LeakyReLU(negative_slope=alpha)
    return nn.Sequential(up_conv, activation)
def affine_flow(W, b, len1, len2, len3, device=None):
    """Evaluate the affine map ``r -> W r + b`` on a centered voxel grid.

    Each axis of the grid holds centered coordinates running from
    ``-(len - 1) / 2`` to ``(len - 1) / 2`` in unit steps. For every voxel with
    coordinates ``r = (x, y, z)`` (x along dim1, y along dim2, z along dim3),
    the output stores the 3-vector ``W @ r + b``, i.e.
    ``out[n, i, j, k, c] = x_i * W[n, c, 0] + y_j * W[n, c, 1] + z_k * W[n, c, 2] + b[n, c]``.

    Args:
        W: tensor of shape ``[N, 3, 3]``; ``W[n, c, d]`` is the weight of
            coordinate ``d`` in output component ``c``.
        b: translation, reshapeable to ``[N, 3]``.
        len1: grid size along dim1 (D of the volume).
        len2: grid size along dim2 (H of the volume).
        len3: grid size along dim3 (W of the volume).
        device: device for the coordinate grids; defaults to ``W.device``.

    Returns:
        Tensor of shape ``[N, len1, len2, len3, 3]`` in channels-last (TF NDHWC)
        layout. PyTorch callers need ``out.permute(0, 4, 1, 2, 3)`` to obtain
        ``[N, 3, D, H, W]``.
    """
    if device is None:
        device = W.device
    b = torch.reshape(b, [-1, 1, 1, 1, 3])
    # Centered coordinates per axis, each shaped to broadcast over [N, D, H, W, 3].
    xr = torch.arange(-(len1 - 1) / 2.0, len1 / 2.0, 1.0, device=device)
    xr = torch.reshape(xr, [1, -1, 1, 1, 1])
    yr = torch.arange(-(len2 - 1) / 2.0, len2 / 2.0, 1.0, device=device)
    yr = torch.reshape(yr, [1, 1, -1, 1, 1])
    zr = torch.arange(-(len3 - 1) / 2.0, len3 / 2.0, 1.0, device=device)
    zr = torch.reshape(zr, [1, 1, 1, -1, 1])
    # Column d of W holds the weights of coordinate d for all 3 output
    # components; place those components on the trailing (channel) axis.
    wx = torch.reshape(W[:, :, 0], [-1, 1, 1, 1, 3])
    wy = torch.reshape(W[:, :, 1], [-1, 1, 1, 1, 3])
    wz = torch.reshape(W[:, :, 2], [-1, 1, 1, 1, 3])
    return (xr * wx + yr * wy) + (zr * wz + b)
def det3x3(M):
    """Batched 3x3 determinant via the rule of Sarrus.

    Args:
        M: tensor indexable as ``M[:, i, j]``, typically ``[N, 3, 3]``.

    Returns:
        Tensor of shape ``[N]`` with ``det(M[n])`` for every batch entry.
    """
    (a, b, c), (d, e, f), (g, h, i) = [
        [M[:, row, col] for col in range(3)] for row in range(3)
    ]
    positive_diagonals = a * e * i + b * f * g + c * d * h
    negative_diagonals = a * f * h + b * d * i + c * e * g
    return positive_diagonals - negative_diagonals
def elem_sym_polys_of_eigen_values(M):
    """Elementary symmetric polynomials of the eigenvalues of batched 3x3 matrices.

    These are the coefficients of the characteristic polynomial, computed
    directly from the entries without an eigendecomposition:
    ``sigma1 = trace``, ``sigma2 = sum of principal 2x2 minors``,
    ``sigma3 = determinant``.

    Args:
        M: tensor indexable as ``M[:, i, j]``, typically ``[N, 3, 3]``.

    Returns:
        Tuple ``(sigma1, sigma2, sigma3)``, each of shape ``[N]``.
    """
    (a, b, c), (d, e, f), (g, h, i) = [
        [M[:, row, col] for col in range(3)] for row in range(3)
    ]
    trace = a + e + i
    minors_sum = (a * e + e * i + i * a) - (b * d + f * h + g * c)
    determinant = (a * e * i + b * f * g + c * d * h) - (a * f * h + b * d * i + c * e * g)
    return trace, minors_sum, determinant
class VTNAffineStem(nn.Module):
    """Affine registration stem: regresses an affine transform from an image pair.

    The two volumes are concatenated on the channel axis (conv1 expects 2
    channels in total). The encoder is six stride-2 LeakyReLU conv blocks with
    three stride-1 refinements. Two final convolutions, whose kernel covers the
    whole remaining feature map, regress the 9 entries of ``W = A - I`` and the
    3 entries of the translation ``b``.

    Args:
        flow_multiplier: scale applied to the predicted ``W`` and ``b``.
        im_size: spatial size ``(D, H, W)`` of the input volumes. It only sets
            the kernel size of ``conv7_W`` / ``conv7_b`` so that their output is
            1x1x1. The default ``(64, 192, 192)`` gives the kernel ``(1, 3, 3)``.

    Raises:
        ValueError: if ``im_size`` does not have exactly 3 entries.
    """

    # Number of stride-2 convolutions in the encoder (conv1 ... conv6).
    _NUM_DOWNSAMPLES = 6

    def __init__(self, flow_multiplier=1, im_size=(64, 192, 192)):
        super(VTNAffineStem, self).__init__()
        if len(im_size) != 3:
            raise ValueError(f"im_size must have 3 entries (D, H, W), got {im_size!r}")
        self.flow_multiplier = flow_multiplier
        # Zero-element parameter. It is no longer used for device lookup, but
        # is kept so that state_dicts containing 'dummy_param' still load.
        self.dummy_param = nn.Parameter(torch.empty(0))
        self.conv1 = ConvLeakyReLU(2, 16, 3, 2, 1)
        self.conv2 = ConvLeakyReLU(16, 32, 3, 2, 1)
        self.conv3 = ConvLeakyReLU(32, 64, 3, 2, 1)
        self.conv3_1 = ConvLeakyReLU(64, 64, 3, 1, 1)
        self.conv4 = ConvLeakyReLU(64, 128, 3, 2, 1)
        self.conv4_1 = ConvLeakyReLU(128, 128, 3, 1, 1)
        self.conv5 = ConvLeakyReLU(128, 256, 3, 2, 1)
        self.conv5_1 = ConvLeakyReLU(256, 256, 3, 1, 1)
        self.conv6 = ConvLeakyReLU(256, 512, 3, 2, 1)
        self.conv6_1 = ConvLeakyReLU(512, 512, 3, 1, 1)
        # Kernel spans the whole conv6_1 feature map, so the output is 1x1x1.
        kernel = self._bottleneck_size(im_size)
        self.conv7_W = nn.Conv3d(512, 9, kernel, 1)
        self.conv7_b = nn.Conv3d(512, 3, kernel, 1)

    @classmethod
    def _bottleneck_size(cls, im_size):
        """Spatial size of conv6_1's output for an input of spatial size ``im_size``."""
        size = []
        for n in im_size:
            for _ in range(cls._NUM_DOWNSAMPLES):
                # Conv3d(kernel=3, stride=2, padding=1): out = floor((n + 2 - 3) / 2) + 1
                n = (n - 1) // 2 + 1
            size.append(n)
        return tuple(size)

    def forward(self, img1, img2):
        """Predict the affine transform for the pair plus its regularization losses.

        Args:
            img1: volume of shape ``[N, C1, D, H, W]``.
            img2: volume of shape ``[N, C2, D, H, W]`` with ``C1 + C2 == 2``.

        Returns:
            dict with:
              ``'flow'``: ``[N, 3, D, H, W]`` displacement ``W r + b`` over
                  centered voxel coordinates ``r = (d, h, w)``; channel ``c``
                  is component ``c``.
              ``'W'``: ``[N, 3, 3]`` predicted ``A - I`` (times flow_multiplier).
              ``'b'``: ``[N, 3]`` predicted translation (times flow_multiplier).
              ``'det_loss'``: batch sum of ``(det(A) - 1)^2 / 2``.
              ``'ortho_loss'``: batch sum of the orthogonality penalty on ``A``.

        Raises:
            ValueError: if the input size does not reduce to a 1x1x1 conv7
                output (the model was built for a different ``im_size``).
        """
        x = torch.cat([img1, img2], dim=1)
        for block in (self.conv1, self.conv2, self.conv3, self.conv3_1,
                      self.conv4, self.conv4_1, self.conv5, self.conv5_1,
                      self.conv6, self.conv6_1):
            x = block(x)
        x_conv7_W = self.conv7_W(x)
        x_conv7_b = self.conv7_b(x)
        # Anything other than 1x1x1 would be silently folded into the batch
        # axis by the reshapes below.
        if x_conv7_W.shape[2:] != (1, 1, 1):
            raise ValueError(
                f"conv7 output has spatial size {tuple(x_conv7_W.shape[2:])}, "
                f"expected (1, 1, 1); input spatial size {tuple(img1.shape[2:])} "
                f"does not match the im_size this model was built for"
            )

        W = torch.reshape(x_conv7_W, [-1, 3, 3]) * self.flow_multiplier
        b = torch.reshape(x_conv7_b, [-1, 3]) * self.flow_multiplier
        I = torch.eye(3, dtype=W.dtype, device=W.device).unsqueeze(0)
        # The network predicts W = A - I, so the displacement is (A r + b) - r = W r + b.
        A = W + I

        sx, sy, sz = img1.shape[2:]
        flow = affine_flow(W, b, sx, sy, sz, W.device)
        # affine_flow returns channels-last [N, D, H, W, 3]. Move the component
        # axis to dim 1 to get [N, 3, D, H, W]; a plain transpose(1, -1) would
        # give [N, 3, H, W, D] and scramble the spatial axes.
        flow = flow.permute(0, 4, 1, 2, 3)

        det = det3x3(A)
        det_loss = torch.sum((det - 1.0) ** 2 / 2.0)

        eps = 1e-5
        # Adding eps * I makes C = A^T A positive definite, so s3 = det(C) > 0.
        C = torch.bmm(A.transpose(1, 2), A) + eps * I
        s1, s2, s3 = elem_sym_polys_of_eigen_values(C)
        # With eigenvalues l_i of C: s1 = sum(l_i) and s2 / s3 = sum(1 / l_i), so
        # this equals sum_i(l_i + (1 + eps)^2 / l_i) - 6 (1 + eps) >= 0, which is
        # zero iff every l_i = 1 + eps, i.e. A^T A = I.
        ortho_loss = s1 + (1 + eps) * (1 + eps) * s2 / s3 - 3 * 2 * (1 + eps)
        ortho_loss = torch.sum(ortho_loss)

        return {
            'flow': flow,
            'W': W,
            'b': b,
            'det_loss': det_loss,
            'ortho_loss': ortho_loss
        }
if __name__ == '__main__':
    # Smoke test on random volumes; falls back to CPU when CUDA is unavailable
    # instead of crashing on .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dummy_input1 = torch.randn(1, 1, 64, 192, 192, device=device)
    dummy_input2 = torch.randn(1, 1, 64, 192, 192, device=device)
    model = VTNAffineStem().to(device)
    # Inference only: no autograd graph needed for a shape check.
    with torch.no_grad():
        out = model(dummy_input1, dummy_input2)
    print(out.keys())
    print(out['flow'].shape)
# Spent ages debugging this; in the end it took my senior labmate to get it working.