Paper: ViT-V-Net: Vision Transformer for Unsupervised Volumetric Medical Image Registration
Source code: https://github.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch

config.py:
Sets up the model's initial configuration parameters.
import ml_collections

def get_3DReg_config():
    config = ml_collections.ConfigDict()
    # ViT patch settings: 8x8x8 patches on an 8x8x8 grid
    config.patches = ml_collections.ConfigDict({'size': (8, 8, 8)})
    config.patches.grid = (8, 8, 8)
    config.hidden_size = 252  # token embedding dimension (divisible by num_heads)
    # Transformer encoder settings
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.patch_size = 8
    # Convolutional encoder/decoder settings for the V-Net-style backbone
    config.conv_first_channel = 512
    config.encoder_channels = (16, 32, 32)
    config.down_factor = 2
    config.down_num = 2
    config.decoder_channels = (96, 48, 32, 32, 16)
    config.skip_channels = (32, 32, 32, 32, 16)
    config.n_dims = 3
    config.n_skip = 5
    return config
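A minimal usage sketch (assuming config.py is importable from the repo root; the printed values follow from the settings above):

from config import get_3DReg_config

cfg = get_3DReg_config()
print(cfg.hidden_size)            # 252
print(cfg.transformer.num_heads)  # 12
# hidden_size must split evenly across the heads: 252 / 12 = 21 dims per head
print(cfg.hidden_size % cfg.transformer.num_heads)  # 0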
models.py:
Multi-head attention

Conceptually, the Attention module is the orange block in the figure above.
import math

import torch
import torch.nn as nn
from torch.nn import Dropout, Linear, Softmax

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis  # whether to return the attention maps for visualization
        self.num_attention_heads = config.transformer["num_heads"]
        # per-head dimension: hidden_size split evenly across the heads
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Q, K, V projections plus the output projection
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # (B, N, all_head_size) -> (B, num_heads, N, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # scaled dot-product attention: softmax(QK^T / sqrt(d)) V
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)
        # merge the heads back: (B, num_heads, N, head_size) -> (B, N, all_head_size)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
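A quick shape check for the block, as a sketch (the batch size of 2 and token count of 64 below are illustrative assumptions, not values from the repo):

cfg = get_3DReg_config()
attn = Attention(cfg, vis=True)
x = torch.randn(2, 64, cfg.hidden_size)  # (batch, tokens, hidden_size), illustrative sizes
out, weights = attn(x)
print(out.shape)      # torch.Size([2, 64, 252])
print(weights.shape)  # torch.Size([2, 12, 64, 64]): per-head token-to-token attention

With vis=True the module also returns the softmax attention maps alongside the output, which is what the vis flag exposes for visualization.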