介绍
最近打算跑一些3D医学影像分割的实验,并且做一些记录。
分割网络,像著名的U-Net,基本上是由encoder和decoder组成的。我的理解是encoder将图像的信息提取出来,decoder将提取到的信息转化成任务需要的结果(比如分割图、分类)。
为了方便后续比较不同encoder等模块的效果,我打算将常见类型的模块都实现一遍,并做好整理。
今天看了一下ViT模型,它是第一个将注意力机制应用到图像上的工作。原文章的任务是二维图像分类,我根据自己的需求,将这个网络中的encoder模块抽取出来,用于3D医学影像分割。
需要注意的是,我只需要ViT中的transformer Encoder
模块(当然Patch embedding
模块也是附带的),因为原网络的分类任务的其他附加模块我都用不上,咱只拿对自己有用的。
ViT的具体网络见下图,来源:使用pytorch搭建Vision Transformer(vit)模型
代码实现
下面将根据流程,分别实现每部分代码,并且记录一下我的理解。
Patch Embedding
我们都很熟悉CNN在图像任务中的作用,它可以把图像的特征提取为二维矩阵。就像UNet,这个网络全都是基于CNN实现的。这样的好处是,每次卷积操作,都能保证图像特征是二维矩阵,方便后续恢复图像。
但是自注意力代码中,完全没有卷积的影子。自注意力机制一开始是用在NLP领域的,我们知道在做NLP任务之前,需要将文字(英文单词或者中文汉字)转换成编码,也就是embedding。文字的编码是一个一维的向量(或者叫tensor),而不是二维的矩阵。
现在就存在个问题,我们怎么才能把一张图像,转换成像文字那样的一维编码?并且文字是连续的,图像只能是单张的。
Patch Embedding的做法是:(1)将一张图像分成多块,这样就模拟了连续的一串文字。(2)对于每块图像,同样需要使用CNN提取特征,但是特殊之处在于,提取出的特征在长宽高维度进行展平,这样就得到了每块图像的一维特征了。
下面是具体代码:
- 对于输入的tensor格式
[B, C, D, H, W]
的图像数据,用三维卷积提取特征 - 提取之后的特征,在DHW维度上展平,相当于得到了图像块序列
- 然后在图像块序列上添加可学习的位置编码
class PatchEmbed(nn.Module):
"""
3D医学图像的patch embedding
"""
def __init__(self, img_size: Tuple[int, int, int], patch_size: Tuple[int, int, int]
, in_channel: int, embed_dim: int, norm_layer=None):
"""
:param img_size: 三维医学图像大小[D, H, W]
:param patch_size: 分成patch的每个维度大小
:param in_channel: 三维医学图像的channel数
:param embed_dim: 进行embedding之后的channel数
:param norm_layer: 是否使用norm层
"""
super().__init__()
assert (len(img_size) == 3)
self.img_size = img_size
assert (len(patch_size) == 3)
self.patch_size = patch_size
self.in_channel = in_channel
self.embed_dim = embed_dim
# 对三维图像取patch之后,每个维度的patch个数
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
# 一共有多少个patch
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
# 卷积层,提取每个patch的特征
self.project = nn.Conv3d(in_channels=in_channel, out_channels=embed_dim, kernel_size=patch_size,
stride=patch_size)
# 判断是否要norm层
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
# 可学习的位置编码,初始化为0
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim))
def forward(self, x: torch.Tensor):
"""
:param x: 传入的3D医学影像数据tensor格式[B, C, D, H, W]
:return: embedding之后的结果 [B, N, C],N表示token的个数,C表示每个token的维度
"""
B, C, D, H, W = x.shape
assert D == self.img_size[0] and H == self.img_size[1] and W == self.img_size[2] and C == self.in_channel
# 得到每个patch的特征图
x = self.project(x)
# 将DHW维度展平
x = x.flatten(2)
# 调换展平维度和channel维度
x = x.transpose(1, 2)
# norm层
x = self.norm(x)
# 添加位置编码,模拟文字的顺序,这样的效果更好
x += self.pos_embed
return x
Attention模块
这是一个多头自注意力模块,多头的策略是直接将特征向量进行均分。
这部分内容比较复杂,涉及到QKV的理解。我对这部分的理解还不够深,只能做到如下意会:
- query相当于我们对图像块中感兴趣的内容,把它定义为查询
- key相当于图像块中提供给外部的可查询信息,给定一个query和一个key,就能计算出这两者的相似程度,也就是权重
- value相当于图像块中对我们任务有直接帮助的信息
当前图像块对任务最终结果的贡献度,是value的加权和。
Attention代码的流程比较固定,就直接贴代码了。(还是我太菜了)需要注意的一点是,特征输入到注意力模块后的结果,维度不会发生变化。
class Attention(nn.Module):
"""
attention 模块
"""
def __init__(self, dim: int, num_heads: int, qkv_bias: bool = False,
qk_scale=None, attn_drop_ratio: float = 0., project_drop_ratio: float = 0.):
"""
:param dim: 输入token的dimension
:param num_heads: 多头注意力
:param qkv_bias: 生成qkv的时候是否要使用偏置
:param qk_scale: qk相乘得到权重之后,是否需要进行缩放
:param attn_drop_ratio:
:param project_drop_ratio:
"""
super(Attention, self).__init__()
self.dim = dim
self.num_heads = num_heads
# 每个头的dim直接平分
head_dim = dim // num_heads
self.head_dim = head_dim
self.scale = qk_scale or head_dim ** -0.5
# 使用一个MLP计算qkv矩阵
self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_ratio)
# 多头注意力需要拼接多头的结果,还需要一个MLP转换层
self.project = nn.Linear(in_features=dim, out_features=dim)
self.project_drop = nn.Dropout(project_drop_ratio)
def forward(self, x: torch.Tensor):
"""
:param x: 输入经过patch embedding的结果[B, N, C]
:return: 经过多头注意力机制得到的结果 [B, N, C]
"""
# C就是token的维度
B, N, C = x.shape
assert C == self.dim
# 最后得到 [3, B, heads, N, head_dim]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
# [B, heads, N, head_dim]
q, k, v = qkv[0], qkv[1], qkv[2]
# 转置之后,矩阵乘法,只会操作最后两个维度
# [B, heads, N, N]
# 也就是每个query和每个key的权重
attn = (q @ k.transpose(-2, -1)) * self.scale
# 使用softmax处理,对最后一行进行处理
attn = attn.softmax(dim=-1)
# dropout
attn = self.attn_drop(attn)
# 将value与权重加权求和
# 得到多头注意力拼接之后的结果
# [B, N, C]
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# 最后再将结果传入MLP转化
x = self.project(x)
x = self.project_drop(x)
return x
Mlp Block
在一个基本的Encoder Block中,除了多头注意力机制,还有一个Mlp Block。这个模块的作用,我现在也不能理解,就先这么用吧。它的结构比较简单,直接贴代码:
class Mlp(nn.Module):
"""
ViT中的MLP层
"""
def __init__(self, in_features: int, hidden_features: int = None,
out_features: int = None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features if out_features else in_features
hidden_features = hidden_features if hidden_features else in_features
self.fc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(in_features=hidden_features, out_features=out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
"""
:param x:
:return:
"""
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
EncoderBlock
接下来实现基础的encoder block。按照上图中展示的结构直接搭建就行了。代码中使用了drop_path,据说比dropout更好,直接用了。
class EncoderBlock(nn.Module):
"""
这是一个基本的encoder block
"""
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_scale: float = None,
drop_ratio: float = 0.,
attn_drop_ratio: float = 0.,
drop_path_ratio: float = 0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super(EncoderBlock, self).__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop_ratio=attn_drop_ratio, project_drop_ratio=drop_ratio)
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def drop_path(self, x, drop_prob: float = 0., training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
def forward(self, x):
return self.drop_path(x, self.drop_prob, self.training)
TransformerEncoder
最后终于来到了我们的重头戏,TransformerEncoder其实就是多个EncoderBlock的叠加,因为特征向量输入EncoderBlock后,输出的特征向量维度完全没有改变,所以直接串联起来就好了,参数都不需要改!(当然,我在看别人的代码的时候,发现随着EncoderBlock的叠加,drop out rate的设置是递减的,我就先不这么设置了)
在forward函数中,我们选择了要传出的中间特征的个数,这样对于像unet这样需要融合前一阶段特征的模型来说很方便。
class TransformerEncoder(nn.Module):
"""
这里就是多个base encoder的串联
"""
def __init__(self,
img_size: Tuple[int, int, int],
patch_size: Tuple[int, int, int],
in_channel: int,
embed_dim: int,
num_heads: int,
out_branches: List[int], # 表示哪些EncoderBlock的结果需要被输出
depth: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float = None,
drop_ratio: float = 0.,
attn_drop_ratio: float = 0.,
drop_path_ratio: float = 0.,
norm_layer=None,
act_layer=None):
super(TransformerEncoder, self).__init__()
# 判断out_branches范围有效
if len(out_branches) == 0:
out_branches = [depth - 1]
for branch in out_branches:
assert(branch in range(depth))
self.out_branches = out_branches
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channel=in_channel,
embed_dim=embed_dim, norm_layer=norm_layer)
self.encodeBlocks = \
[
EncoderBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,
drop_path_ratio=drop_path_ratio, act_layer=act_layer, norm_layer=norm_layer)
for i in range(depth)
]
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
:param x: 输入图像的tensor [B, C, D, H, W]
:return:
"""
out = []
x = self.patch_embed(x)
for id, encodeBlock in enumerate(self.encodeBlocks):
x = encodeBlock(x)
if id in self.out_branches:
out.append(x)
return out
测试
最后我们来测试一下整体模块的正确性,主要是看输出特征的维度是否正确。
if __name__ == '__main__':
# 创建一个encoder
transformerEncoder = TransformerEncoder(img_size=(128, 256, 256), patch_size=(32, 64, 64), in_channel=1,
embed_dim=32, num_heads=8, out_branches=[0, 4, 8, 11])
test = torch.rand((16, 1, 128, 256, 256))
result = transformerEncoder(test)
for out in result:
print(out.shape)
我们选择输出四个特征层,维度都是正确的
torch.Size([16, 64, 32])
torch.Size([16, 64, 32])
torch.Size([16, 64, 32])
torch.Size([16, 64, 32])
之后打算搭建一个完整的网络,测试这个transformer encode的效果。