TransUNet代码解读

vit_seg_modeling.py

Encoder类的主要功能是将输入序列通过一系列堆叠的编码层进行处理,每一层都包含自注意力机制和前馈神经网络结构,以此提取和融合上下文信息。最后,使用LayerNorm层对编码结果进行规范化,确保输出尺度一致性,方便后续解码器或其他模块使用。如果开启可视化,还会记录各层的注意力权重,以供后续分析。

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()  # 调用父类(nn.Module)的构造函数,初始化Encoder
        self.vis = vis  # 保存可视化标志,用于在训练过程中可视化注意力权重
        self.layer = nn.ModuleList()  # 创建一个模块列表moduleList容器,用于存储编码器中的所有编码块
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)  # 创建一个全局层归一化(LayerNorm),用于编码器的输出

        # 循环创建指定数量的编码块,并添加到模块列表
        #copy.deepcopy(layer) 使用深拷贝技术创建Block实例的一个副本,并将其添加到self.layer列表中。深拷贝确保每个Block实例都是独立的,互不影响各自的权重和其他属性。
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))
    def forward(self, hidden_states):
        attn_weights = []  # 创建一个列表,用于存储每层编码块产生的注意力权重
        # 遍历所有编码块,并将输入传递给它们
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)  # 对输入进行编码,并获取输出和注意力权重
            if self.vis:  # 如果需要可视化注意力权重
                attn_weights.append(weights)  # 将注意力权重添加到列表中

        encoded = self.encoder_norm(hidden_states)  # 应用全局层归一化
        return encoded, attn_weights  # 返回编码后的输出和注意力权重列表

Transformer类是Vision Transformer (ViT) 模型的核心组件,它整合了嵌入层、编码器和解码器,用于处理图像数据。以下是Transformer类中每行代码的解释:

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()  # 调用基类nn.Module的构造函数
        self.embeddings = Embeddings(config, img_size=img_size)  # 创建嵌入层,用于生成图像的嵌入表示
        self.encoder = Encoder(config, vis)  # 创建编码器,由多个编码块组成
        self.decoder = DecoderCup(config)  # 创建解码器,用于将编码器的输出转换为图像特征
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )  # 创建分割头,用于生成最终的分割图
        self.config = config  # 保存配置信息
    def forward(self, input_ids):
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)  # 如果输入只有一个通道,复制通道以匹配模型要求
        embedding_output, features = self.embeddings(input_ids)  # 通过嵌入层获取嵌入表示和特征
        encoded, attn_weights, features = self.encoder(embedding_output, features)  # 通过编码器获取编码特征和注意力权重
        x = self.decoder(encoded, features)  # 通过解码器获取解码特征
        logits = self.segmentation_head(x)  # 通过分割头获取分割结果
        return logits  # 返回分割结果

https://blog.csdn.net/qq_41813454/article/details/129928484
Conv2dReLU类是将卷积层(Conv2d)与ReLU激活函数(ReLU)以及可选的批量归一化层(BatchNorm2d)封装在一起的组合层。下面是该类每行代码的作用解析:

class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
    #创建一个二维卷积层,其参数根据传入的值进行设置。如果use_batchnorm为True,则卷积层自动禁用偏置项(bias=False),因为批量归一化层将接手这部分功能。
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)#创建一个ReLU激活函数,设置为inplace模式,这意味着ReLU操作将直接修改输入张量,而非创建新的张量。

        bn = nn.BatchNorm2d(out_channels)#如果use_batchnorm为True,则实例化一个二维批量归一化层,其通道数等于输出通道数
        super(Conv2dReLU, self).__init__(conv, bn, relu)#将卷积层、批量归一化层(如果启用)和ReLU激活函数按照顺序添加到nn.Sequential容器中,这样当调用这个类的实例的forward方法时,数据会依次经过这三个层。
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels=0, use_batchnorm=True):
        super(DecoderBlock, self).__init__()

        # 初始化一个上采样模块,如双线性插值或者子像素卷积(Transposed Convolution)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # 初始化一个包含卷积、BatchNorm(如果use_batchnorm为True)和ReLU激活层的组合层
        self.conv1 = Conv2dReLU(
            in_channels=in_channels + skip_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

        # 初始化第二个类似的组合层
        self.conv2 = Conv2dReLU(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

    def forward(self, x, skip=None):
        # 上采样输入特征
        x = self.up(x)

        # 如果有跳过连接(skip connection)的特征,将其与上采样的特征拼接
        if skip is not None:
            x = torch.cat([x, skip], dim=1)

        # 将拼接后的特征通过第一个卷积-ReLU层
        x = self.conv1(x)

        # 再将经过第一层处理后的特征通过第二个卷积-ReLU层
        x = self.conv2(x)

        # 返回处理后的特征
        return x

DecoderBlock类的作用是实现特征上采样和特征融合的过程,它包含以下功能:

使用双线性插值等方法对输入特征进行上采样,提升特征图的空间分辨率。
如果存在来自编码器的跳过连接(skip connection)特征,则将上采样的特征与编码器的浅层特征进行拼接,以融合低层的细节信息和高层的语义信息。
使用包含卷积、BatchNorm(如果use_batchnorm设为True)和ReLU激活函数的组合层对拼接后的特征进行处理,这个过程一般重复两次,以增强特征表达能力和网络的非线性性质。
forward方法中描述了特征上采样、特征融合、以及通过两个卷积-ReLU组合层的完整前向传播流程。最终返回处理后的特征图,用于后续解码器层或输出层进行处理。

SegmentationHead类是Vision Transformer (ViT) 模型中的一个组件,用于将解码器的输出转换为最终的分割图。以下是SegmentationHead类中每行代码的解释:

class SegmentationHead(nn.Sequential):
    # 在类的初始化方法中定义了分割头的构成
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        # nn.Sequential()是一个容器,可以按照顺序包含多个模块(或层)
        # 这里它被用来方便地堆叠两个操作:一个卷积层后跟一个上采样层
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        # 卷积层用于将解码器的输出映射到最终的类别数量
        # kernel_size定义了卷积核的大小,padding通常是kernel_size的一半,以避免尺寸减小
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        # 如果upsampling的值大于1,则使用双线性上采样层来增加输出的分辨率
        # 如果upsampling的值是1或更小,则不进行上采样,使用Identity模块,即直接传递输入
        super().__init__(conv2d, upsampling)
        # 调用父类nn.Sequential的初始化方法,
        #将前面定义的卷积层conv2d和上采样层(或者恒等映射层)upsampling按照顺序添加到Sequential模块中。
        #这样,在调用模型时,输入会先通过卷积层,然后根据设定条件决定是否进行上采样操作。

DecoderCup类是Vision Transformer (ViT) 模型中的解码器组件,它的作用是将编码器的输出转换回图像的空间维度,以便进行像素级的分割任务。DecoderCup的主要组成部分包括:

卷积层:用来增加非线性表达能力以及对特征进行进一步处理。
上采样层:通过Cascaded Upsampler(CUP)结构实现特征分辨率的逐级提升,CUP由多个连续的上采样块组成,每个块包括2倍上采样操作、3x3卷积层以及ReLU激活函数。
跳过连接(Skip Connections):DecoderCup形成了一个类似于U-Net的U形结构,允许低层次的高分辨率特征与经过Transformer编码器和解码器处理的高层特征相结合,这种跨层次的信息融合有助于提升分割精度。

DecoderCup首先将Transformer编码器输出的特征进行上采样和重构,并通过卷积层进行特征提取和增强。随后,DecoderCup利用CUP结构逐级提升特征分辨率,并在此过程中结合从编码器阶段获取的高分辨率特征,通过跳过连接传递到相应的解码层,以实现精细的空间定位能力。最终,经过DecoderCup处理的特征被送入SegmentationHead模块,生成最终的像素级分割结果。

class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()  # 调用基类nn.Module的构造函数
        self.config = config  # 保存模型配置
        head_channels = 512  # 定义解码器头部的通道数,这里假设为512.用于计算第一个解码块的输入通道数。
        self.conv_more = Conv2dReLU(  # 创建一个额外的卷积+ReLU层,它包含卷积、批量归一化和ReLU激活函数,用于将Transformer编码器输出的特征转换为适合解码器的特征表示。
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels  # 获取解码器通道配置
        in_channels = [head_channels] + list(decoder_channels[:-1])  # 定义输入通道列表
        out_channels = decoder_channels  # 定义输出通道列表

        if self.config.n_skip != 0:  # 如果模型使用了跳跃连接
            skip_channels = self.config.skip_channels  # 获取要跳过的通道配置
            # 根据n_skip重新选择跳跃通道,以便与解码器块匹配
            for i in range(4-self.config.n_skip):
                skip_channels[3-i]=0
        else:
            skip_channels=[0,0,0,0]  # 如果没有使用跳跃连接,则设置为0

        blocks = [  # 构建一系列DecoderBlock模块,每个模块由几个卷积层、上采样层组成,并且可以接收跳过连接(skip connections)。即使用列表推导式创建多个DecoderBlock实例,每个实例代表解码器中的一个上采样块,并将它们添加到ModuleList中。
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ] #zip()函数用于将两个列表相对应位置的元素打包成元组。
        self.blocks = nn.ModuleList(blocks)  # 将这些解码块以ModuleList的形式组织在一起,便于在前向传播中循环调用。

in_channels = [head_channels] + list(decoder_channels[:-1]):
in_channels 是一个列表,它定义了解码器中每个解码器块(DecoderBlock)的输入通道数。这个列表的构造方式如下:

[head_channels]:首先,列表包含一个元素,即 head_channels,它表示解码器的第一个块(紧接额外卷积层 conv_more 之后)的输入通道数。head_channels 是在 DecoderCup 类的初始化中定义的,通常是根据模型设计而选择的一个值,例如 512。

list(decoder_channels[:-1]):然后,使用列表推导式和切片操作 [:-1] 来获取 decoder_channels 列表中除了最后一个元素之外的所有元素,并将它们转换为一个新的列表。decoder_channels 是模型配置中定义的解码器各层的通道数。切片操作 [:-1] 用于获取除最后一层外的所有层的通道数,因为最后一层的通道数通常用于输出,而不是作为下一层的输入。

将这两部分合并,in_channels 就包含了解码器中每个块的输入通道数,从第一个块开始,直到倒数第二个块。这是因为最后一个块的输入通道数通常是第二个最后一个块的输出通道数,所以不需要显式地包含在 in_channels 列表中。这种设计允许模型在解码器中逐步降低特征图的空间维度,同时增加通道数,以捕捉更丰富的空间信息

    if self.config.n_skip != 0:
        skip_channels = self.config.skip_channels
        for i in range(4-self.config.n_skip):
            skip_channels[3-i] = 0
    else:
        skip_channels=[0,0,0,0]

这段代码的逻辑是根据self.config.n_skip的值来设置skip_channels列表。self.config.n_skip可能代表模型中跳过(skip)连接的数量,这是一种在卷积神经网络(如UNet、ResNet等)中常见的结构,用于将较浅层的特征直接传递给深层或解码阶段以保留更多细节信息。

  1. 如果self.config.n_skip不等于0,也就是说存在跳过连接,则获取配置中的skip_channels列表。

    • 遍历一个从3到(3-self.config.n_skip)的范围(包括3但不包括3-self.config.n_skip),即将索引从3开始往前遍历至第self.config.n_skip层之后的层。
    • 对于这个范围内的索引i,将skip_channels列表中相应位置的元素赋值为0,意味着这些层级不使用跳过连接。
  2. 如果self.config.n_skip等于0,表示不使用任何跳过连接,则直接将skip_channels列表初始化为 [0,0,0,0],即所有层级都不使用跳过连接。
    总结来说,这段代码是在根据配置调整模型中各层是否采用跳过连接,若n_skip较大,则越靠近输出层的几层不使用跳过连接;若n_skip为0,则所有层均不使用跳过连接。

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # 获取隐藏状态的尺寸
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))  # 计算每个维度的补丁数
        x = hidden_states.permute(0, 2, 1)  # 调整隐藏状态的维度
        x = x.contiguous().view(B, hidden, h, w)  # 将隐藏状态重塑为B x hidden x h x w
        x = self.conv_more(x)  # 通过额外的卷积层

        for i, decoder_block in enumerate(self.blocks):  # 遍历解码器块
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None  # 获取跳跃连接的特征
            else:
                skip = None
            x = decoder_block(x, skip=skip)  # 将解码器块应用于x,并添加跳跃连接特征

        return x  # 返回解码器的输出
  • 20
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值