【encoder模块】transformer_encoder编写

介绍

最近打算跑一些3D医学影像分割的实验,并且做一些记录。
分割网络,像著名的U-Net,基本上是由encoder和decoder组成的。我的理解是encoder将图像的信息提取出来,decoder将提取到的信息转化成任务需要的结果(比如分割图、分类)。

为了方便后续比较不同encoder等模块的效果,我打算将常见类型的模块都实现一遍,并做好整理。

今天看了一下ViT模型,它是第一个将注意力机制应用到图像上的工作。原文章的任务是二维图像分类,我根据自己的需求,将这个网络中的encoder模块抽取出来,用于3D医学影像分割。

需要注意的是,我只需要ViT中的transformer Encoder模块(当然Patch embedding模块也是附带的),因为原网络的分类任务的其他附加模块我都用不上,咱只拿对自己有用的。

ViT的具体网络见下图,来源:使用pytorch搭建Vision Transformer(vit)模型
在这里插入图片描述

代码实现

下面将根据流程,分别实现每部分代码,并且记录一下我的理解。

Patch Embedding

我们都很熟悉CNN在图像任务中的作用,它可以把图像的特征提取为二维矩阵。就像UNet,这个网络全都是基于CNN实现的。这样的好处是,每次卷积操作,都能保证图像特征是二维矩阵,方便后续恢复图像。

但是自注意力代码中,完全没有卷积的影子。自注意力机制一开始是用在NLP领域的,我们知道在做NLP任务之前,需要将文字(英文单词或者中文汉字)转换成编码,也就是embedding。文字的编码是一个一维的向量(或者叫tensor),而不是二维的矩阵。

现在就存在个问题,我们怎么才能把一张图像,转换成像文字那样的一维编码?并且文字是连续的,图像只能是单张的。

Patch Embedding的做法是:(1)将一张图像分成多块,这样就模拟了连续的一串文字。(2)对于每块图像,同样需要使用CNN提取特征,但是特殊之处在于,提取出的特征在长宽高维度进行展平,这样就得到了每块图像的一维特征了。

下面是具体代码:

  • 对于输入的tensor格式[B, C, D, H, W]的图像数据,用三维卷积提取特征
  • 提取之后的特征,在DHW维度上展平,相当于得到了图像块序列
  • 然后在图像块序列上添加可学习的位置编码
class PatchEmbed(nn.Module):
    """
    3D医学图像的patch embedding
    """

    def __init__(self, img_size: Tuple[int, int, int], patch_size: Tuple[int, int, int]
                 , in_channel: int, embed_dim: int, norm_layer=None):
        """

        :param img_size: 三维医学图像大小[D, H, W]
        :param patch_size: 分成patch的每个维度大小
        :param in_channel: 三维医学图像的channel数
        :param embed_dim: 进行embedding之后的channel数
        :param norm_layer: 是否使用norm层
        """
        super().__init__()
        assert (len(img_size) == 3)
        self.img_size = img_size
        assert (len(patch_size) == 3)
        self.patch_size = patch_size

        self.in_channel = in_channel
        self.embed_dim = embed_dim

        # 对三维图像取patch之后,每个维度的patch个数
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
        # 一共有多少个patch
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

        # 卷积层,提取每个patch的特征
        self.project = nn.Conv3d(in_channels=in_channel, out_channels=embed_dim, kernel_size=patch_size,
                                 stride=patch_size)

        # 判断是否要norm层
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        # 可学习的位置编码,初始化为0
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.embed_dim))

    def forward(self, x: torch.Tensor):
        """

        :param x: 传入的3D医学影像数据tensor格式[B, C, D, H, W]
        :return: embedding之后的结果 [B, N, C],N表示token的个数,C表示每个token的维度
        """
        B, C, D, H, W = x.shape
        assert D == self.img_size[0] and H == self.img_size[1] and W == self.img_size[2] and C == self.in_channel

        # 得到每个patch的特征图
        x = self.project(x)
        # 将DHW维度展平
        x = x.flatten(2)
        # 调换展平维度和channel维度
        x = x.transpose(1, 2)

        # norm层
        x = self.norm(x)

        # 添加位置编码,模拟文字的顺序,这样的效果更好
        x += self.pos_embed

        return x

Attention模块

这是一个多头自注意力模块,多头的策略是直接将特征向量进行均分。

这部分内容比较复杂,涉及到QKV的理解。我对这部分的理解还不够深,只能做到如下意会:

  • query相当于我们对图像块中感兴趣的内容,把它定义为查询
  • key相当于图像块中提供给外部的可查询信息,给定一个query和一个key,就能计算出这两者的相似程度,也就是权重
  • value相当于图像块中对我们任务有直接帮助的信息

当前图像块对任务最终结果的贡献度,是value的加权和。

Attention代码的流程比较固定,就直接贴代码了。(还是我太菜了)需要注意的一点是,特征输入到注意力模块后的结果,维度不会发生变化。

class Attention(nn.Module):
    """
    attention 模块
    """

    def __init__(self, dim: int, num_heads: int, qkv_bias: bool = False,
                 qk_scale=None, attn_drop_ratio: float = 0., project_drop_ratio: float = 0.):
        """

        :param dim: 输入token的dimension
        :param num_heads: 多头注意力
        :param qkv_bias: 生成qkv的时候是否要使用偏置
        :param qk_scale: qk相乘得到权重之后,是否需要进行缩放
        :param attn_drop_ratio:
        :param project_drop_ratio:
        """

        super(Attention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        # 每个头的dim直接平分
        head_dim = dim // num_heads
        self.head_dim = head_dim
        self.scale = qk_scale or head_dim ** -0.5

        # 使用一个MLP计算qkv矩阵
        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)

        # 多头注意力需要拼接多头的结果,还需要一个MLP转换层
        self.project = nn.Linear(in_features=dim, out_features=dim)
        self.project_drop = nn.Dropout(project_drop_ratio)

    def forward(self, x: torch.Tensor):
        """

        :param x: 输入经过patch embedding的结果[B, N, C]
        :return: 经过多头注意力机制得到的结果 [B, N, C]
        """
        # C就是token的维度
        B, N, C = x.shape

        assert C == self.dim

        # 最后得到 [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)

        # [B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 转置之后,矩阵乘法,只会操作最后两个维度
        # [B, heads, N, N]
        # 也就是每个query和每个key的权重
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # 使用softmax处理,对最后一行进行处理
        attn = attn.softmax(dim=-1)

        # dropout
        attn = self.attn_drop(attn)

        # 将value与权重加权求和
        # 得到多头注意力拼接之后的结果
        # [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        # 最后再将结果传入MLP转化
        x = self.project(x)
        x = self.project_drop(x)

        return x

Mlp Block

在一个基本的Encoder Block中,除了多头注意力机制,还有一个Mlp Block。这个模块的作用,我现在也不能理解,就先这么用吧。它的结构比较简单,直接贴代码:

class Mlp(nn.Module):
    """
    ViT中的MLP层
    """

    def __init__(self, in_features: int, hidden_features: int = None,
                 out_features: int = None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """

        :param x:
        :return:
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

EncoderBlock

接下来实现基础的encoder block。按照上图中展示的结构直接搭建就行了。代码中使用了drop_path,据说比dropout更好,直接用了。

class EncoderBlock(nn.Module):
    """
    这是一个基本的encoder block
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 qk_scale: float = None,
                 drop_ratio: float = 0.,
                 attn_drop_ratio: float = 0.,
                 drop_path_ratio: float = 0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(EncoderBlock, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, project_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def drop_path(self, x, drop_prob: float = 0., training: bool = False):
        """
        Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.
        """
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output

    def forward(self, x):
        return self.drop_path(x, self.drop_prob, self.training)

TransformerEncoder

最后终于来到了我们的重头戏,TransformerEncoder其实就是多个EncoderBlock的叠加,因为特征向量输入EncoderBlock后,输出的特征向量维度完全没有改变,所以直接串联起来就好了,参数都不需要改!(当然,我在看别人的代码的时候,发现随着EncoderBlock的叠加,drop out rate的设置是递减的,我就先不这么设置了)

在forward函数中,我们选择了要传出的中间特征的个数,这样对于像unet这样需要融合前一阶段特征的模型来说很方便。

class TransformerEncoder(nn.Module):
    """
    这里就是多个base encoder的串联
    """

    def __init__(self,
                 img_size: Tuple[int, int, int],
                 patch_size: Tuple[int, int, int],
                 in_channel: int,
                 embed_dim: int,
                 num_heads: int,
                 out_branches: List[int],  # 表示哪些EncoderBlock的结果需要被输出
                 depth: int = 12,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 drop_ratio: float = 0.,
                 attn_drop_ratio: float = 0.,
                 drop_path_ratio: float = 0.,
                 norm_layer=None,
                 act_layer=None):
        super(TransformerEncoder, self).__init__()

        # 判断out_branches范围有效
        if len(out_branches) == 0:
            out_branches = [depth - 1]

        for branch in out_branches:
            assert(branch in range(depth))

        self.out_branches = out_branches

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channel=in_channel,
                                      embed_dim=embed_dim, norm_layer=norm_layer)

        self.encodeBlocks = \
            [
                EncoderBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                             qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,
                             drop_path_ratio=drop_path_ratio, act_layer=act_layer, norm_layer=norm_layer)
                for i in range(depth)
            ]


    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """

        :param x: 输入图像的tensor [B, C, D, H, W]
        :return:
        """
        out = []
        x = self.patch_embed(x)

        for id, encodeBlock in enumerate(self.encodeBlocks):
            x = encodeBlock(x)
            if id in self.out_branches:
                out.append(x)

        return out

测试

最后我们来测试一下整体模块的正确性,主要是看输出特征的维度是否正确。

if __name__ == '__main__':
    # 创建一个encoder
    transformerEncoder = TransformerEncoder(img_size=(128, 256, 256), patch_size=(32, 64, 64), in_channel=1,
                                            embed_dim=32, num_heads=8, out_branches=[0, 4, 8, 11])

    test = torch.rand((16, 1, 128, 256, 256))
    result = transformerEncoder(test)

    for out in result:
        print(out.shape)

我们选择输出四个特征层,维度都是正确的
torch.Size([16, 64, 32])
torch.Size([16, 64, 32])
torch.Size([16, 64, 32])
torch.Size([16, 64, 32])

之后打算搭建一个完整的网络,测试这个transformer encode的效果。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在这段代码中,有几个问题需要进行修正。 首先,你定义了一个名为 `FCNTransformerNet` 的类,但是没有在代码末尾实例化这个类。你需要添加 `()` 来实例化该类,将其赋值给变量 `net`。 其次,你需要将 `FCNTransformerNet` 类的定义放在 `if __name__ == "__main__":` 条件语句内部,以确保它只在主程序中运行,而不是在模块导入时运行。 最后,你的代码中缺少了一些必要的导入语句。你需要导入 `torch`, `torch.nn` 和 `torchvision.models`。 下面是修改后的代码: ```python import torch import torch.nn as nn import torchvision.models as models class FCNTransformerNet(nn.Module): def __init__(self, num_classes): super(FCNTransformerNet, self).__init__() self.fcn_backbone = models.segmentation.fcn_resnet50(pretrained=True).backbone self.fcn_backbone.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.transformer_layers = nn.TransformerEncoderLayer(d_model=2048, nhead=8) self.transformer_encoder = nn.TransformerEncoder(self.transformer_layers, num_layers=6) self.classification_head = nn.Sequential( nn.Linear(2048, 512), nn.ReLU(), nn.Linear(512, num_classes) ) def forward(self, x): fcn_output = self.fcn_backbone(x)['out'] fcn_output = fcn_output.view(fcn_output.size(0), fcn_output.size(1), -1) fcn_output = fcn_output.permute(2, 0, 1) transformer_output = self.transformer_encoder(fcn_output) transformer_output = transformer_output.permute(1, 2, 0) transformer_output = transformer_output.contiguous().view(transformer_output.size(0), -1, 1, 1) output = self.classification_head(transformer_output) return output if __name__ == "__main__": net = FCNTransformerNet(num_classes=2) input_batch = torch.randn(4, 3, 512, 512) output_batch = net(input_batch) print(output_batch.size()) # Should print: torch.Size([4, 2, 512, 512]) ``` 请注意,这段代码假设你已经正确安装了 `torch` 和 `torchvision` 库。如果出现任何错误,你可能需要检查这些库的安装情况。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值