分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。

论文地址:https://arxiv.org/pdf/2303.07428.pdf

具体的网络结构如下:

网络的原理还是比较简单的,

编码分支用的是预训练的resnet模块,解码分支则重新设计了。

解码器分支的模块结构示意图如下:

可以看出来,就是Transformer模块和残差连接相加,然后再经过一个residual模块处理。

1,用pytorch实现时,首先要把这个解码器模块实现出来:

class residual_transformer_block(nn.Module):
    def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None):
        super().__init__()

        self.ps = patch_size
        self.c1 = Conv2D(in_c, out_c)

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
        self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)
        self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.r1 = residual_block(out_c, out_c)

    def forward(self, inputs):
        x = self.c1(inputs)

        b, c, h, w = x.shape
        num_patches = (h*w)//(self.ps**2)
        x = torch.reshape(x, (b, (self.ps**2)*c, num_patches))
        x = self.te(x)
        x = torch.reshape(x, (b, c, h, w))

        x = self.c2(x)
        s = self.c3(inputs)
        x = self.relu(x + s)
        x = self.r1(x)
        return x

其中我们需要注意的是这里构建Transformer块的方法,也就是下面两句:

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

首先,第一句是用nn.TransformerEncoderLayer定义了一个Transformer层,并存储在encoder_layer变量中。

nn.TransformerEncoderLayer的参数包括:d_model(输入特征的维度大小),nhead(自注意力机制中注意力头数量),dim_feedforward(前馈网络的隐藏层维度大小),dropout(dropout比例),apply(用于在编码器层及其子层上应用函数,例如初始化或者权重共享等功能)。

第二句则是将多个Transformer层堆叠在一起,构建一个Transformer编码器块。

nn.TransformerEncoder的参数包括:encoder_layer(用于构建模块的每个Transformer层),num_layer(堆叠的层数),norm(执行的标准化方法),apply(同上)。

2,在上面的解码器模块中,还有一个residual block需要额外实现,如下:

class residual_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c)
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, inputs):
        x = self.conv(inputs)
        s = self.shortcut(inputs)
        return self.relu(x + s)

这个代码就是简单的残差卷积模块,不赘述。

3,重要的模块实现完了,接下来就是按照UNet的形状拼装网络,代码如下:

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        backbone = resnet50()
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self.e1 = Conv2D(64, 64, kernel_size=1, padding=0)
        self.e2 = Conv2D(256, 64, kernel_size=1, padding=0)
        self.e3 = Conv2D(512, 64, kernel_size=1, padding=0)
        self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0)


        """ Decoder """
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.r1 = residual_transformer_block(64+64, 64, dim=64)
        self.r2 = residual_transformer_block(64+64, 64, dim=256)
        self.r3 = residual_block(64+64, 64)

        """ Classifier """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """ Encoder """
        x0 = inputs
        x1 = self.layer0(x0)    ## [-1, 64, h/2, w/2]
        x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4]
        x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8]
        x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16]

        e1 = self.e1(x1)
        e2 = self.e2(x2)
        e3 = self.e3(x3)
        e4 = self.e4(x4)

        """ Decoder """
        x = self.up(e4)
        x = torch.cat([x, e3], axis=1)
        x = self.r1(x)

        x = self.up(x)
        x = torch.cat([x, e2], axis=1)
        x = self.r2(x)

        x = self.up(x)
        x = torch.cat([x, e1], axis=1)
        x = self.r3(x)

        x = self.up(x)

        """ Classifier """
        outputs = self.outputs(x)
        return outputs

其中,x1,x2,x3,x4就是编码器模块,用的都是resnet50的预训练模块。

其中r1,r2,r3,r4则是解码器的模块,就是上面实现的模块。

而e1,e2,e3,e4则是在skip connection前给编码器的输出做1x1卷积,作用大体上就是减少计算量。

完整代码:https://github.com/DebeshJha/TransNetR/blob/main/model.py#L45

  • 11
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

蓝海渔夫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值