论文介绍
题目:DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
论文地址:https://arxiv.org/abs/2106.06716
创新点
-
双尺度编码器:论文提出了一种基于双分支的编码器架构,使用不同尺度的图像块(patch)进行特征提取。这种双尺度方法可以同时捕捉粗粒度和细粒度的特征,从而提升了语义分割的效果。
-
Transformer交互融合模块(TIF):提出了一个新颖的TIF模块,通过Transformer的自注意力机制,有效地融合了来自双尺度编码器的多尺度特征表示。这种融合方式建立了特征间的全局依赖关系,从而保证了多尺度特征的语义一致性。
-
在解码器中引入Swin Transformer:创新性地在U-Net解码器中使用了Swin Transformer模块,不仅在下采样阶段建模了长程依赖,还在上采样阶段进一步提升了上下文信息的利用效率。
-
全面的实验验证:通过四个典型的医学图像分割任务(如息肉分割、皮肤病变分割等)的实验,展示了DS-TransUNet在分割质量上优于现有的最先进方法,尤其是在息肉分割任务中表现突出。
方法
整体架构
DS-TransUNet是一种基于双分支编码器的U型网络结构,融合了Swin Transformer的长程依赖建模能力。它通过双尺度编码器提取粗粒度和细粒度特征,利用Transformer交互融合模块(TIF)实现多尺度特征的全局交互,在解码器中进一步引入Swin Transformer块建模全局上下文,从而实现高效的医学图像分割。这种架构能够捕捉丰富的多尺度信息,并在多个分割任务中表现出色。
1. 双分支编码器(Dual-Branch Encoder)
-
双尺度特征提取:输入的医学图像被分割为两种不同尺度的非重叠图像块(patch),分别通过两个独立的分支处理:
-
主分支:处理细粒度图像块(较小尺寸的patch),提取细粒度特征。
-
辅分支:处理粗粒度图像块(较大尺寸的patch),提取粗粒度特征。
-
特征提取器:每个分支使用分层的 Swin Transformer 作为编码器,对图像块进行特征表示学习,并通过多个阶段逐步提取高层次特征。
2. Transformer交互融合模块(Transformer Interactive Fusion, TIF)
-
特征融合:通过标准Transformer块的自注意力机制,融合双分支(粗粒度和细粒度)的特征表示。
-
全局依赖建模:TIF模块能够捕捉不同尺度特征之间的全局依赖关系,并在特征间实现高效交互。
3. 解码器(Decoder)
-
上采样与跳跃连接:解码器采用逐层上采样的方式,并利用编码器对应层的特征通过跳跃连接(Skip Connections)来恢复原始分辨率。
-
引入Swin Transformer块:在每个解码阶段加入Swin Transformer块,以建模长程依赖和全局上下文信息,从而提升解码器的表现。
-
最终输出:融合后的特征被逐步恢复为与输入图像相同的分辨率,生成像素级的分割结果。
即插即用模块作用
TIF 作为一个即插即用模块:
-
多尺度特征融合:TIF模块利用自注意力机制,在不同尺度的特征之间建立全局交互,提升多尺度特征的融合效果,保证语义一致性。
-
增强全局上下文信息:通过全局依赖建模,TIF模块能够在特征中注入丰富的上下文信息,提高目标分割的准确性和鲁棒性。
-
提升分割细节表现:对于边界复杂或细粒度分割任务,TIF模块能有效提升目标边界的分割质量,减少边界模糊现象。
-
即插即用的灵活性:TIF模块可以作为现有深度学习模型(如U-Net、FPN)的插件模块,无需对整体结构进行大幅修改,即可显著提升模型性能。
消融实验结果
表 VIII 展示了不同模型配置(Base Model、Swin U-Net、Swin Decoder、Multi-Scale SD和DS-TransUNet)在息肉分割任务上的性能对比。实验验证了Swin Transformer作为编码器的有效性、Swin Decoder的长程依赖建模能力,以及TIF模块在多尺度特征融合中的关键作用。最终模型DS-TransUNet在所有数据集上的分割性能均优于其他配置。
图 4 展示了DS-TransUNet在息肉分割任务中(包括Kvasir、CVC-ClinicDB及多个数据集)的定性分割结果。与其他模型相比,DS-TransUNet表现出更强的边界捕捉能力,特别是在处理模糊、颜色与背景相近或边缘复杂的息肉时,其分割结果更接近真实边界。
即插即用模块
import torch
from torch import nn, einsum
from einops import rearrange
#论文:DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
#论文地址:https://arxiv.org/abs/2106.06716
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _ = x.shape
h = self.heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class CrossAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.to_k = nn.Linear(dim, inner_dim , bias=False)
self.to_v = nn.Linear(dim, inner_dim , bias = False)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x_qkv):
b, n, _ = x_qkv.shape
h = self.heads
k = self.to_k(x_qkv)
k = rearrange(k, 'b n (h d) -> b h n d', h = h)
v = self.to_v(x_qkv)
v = rearrange(v, 'b n (h d) -> b h n d', h = h)
q = self.to_q(x_qkv[:, 0].unsqueeze(1))
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class TIF(nn.Module):
def __init__(self, dim_s, dim_l):
super().__init__()
self.transformer_s = Transformer(dim=dim_s, depth=1, heads=3, dim_head=32, mlp_dim=128)
self.transformer_l = Transformer(dim=dim_l, depth=1, heads=1, dim_head=64, mlp_dim=256)
self.norm_s = nn.LayerNorm(dim_s)
self.norm_l = nn.LayerNorm(dim_l)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.linear_s = nn.Linear(dim_s, dim_l)
self.linear_l = nn.Linear(dim_l, dim_s)
def forward(self, e, r):
b_e, c_e, h_e, w_e = e.shape
e = e.reshape(b_e, c_e, -1).permute(0, 2, 1)
b_r, c_r, h_r, w_r = r.shape
r = r.reshape(b_r, c_r, -1).permute(0, 2, 1)
e_t = torch.flatten(self.avgpool(self.norm_l(e).transpose(1, 2)), 1)
r_t = torch.flatten(self.avgpool(self.norm_s(r).transpose(1, 2)), 1)
e_t = self.linear_l(e_t).unsqueeze(1)
r_t = self.linear_s(r_t).unsqueeze(1)
r = self.transformer_s(torch.cat([e_t, r], dim=1))[:, 1:, :]
e = self.transformer_l(torch.cat([r_t, e], dim=1))[:, 1:, :]
e = e.permute(0, 2, 1).reshape(b_e, c_e, h_e, w_e)
r = r.permute(0, 2, 1).reshape(b_r, c_r, h_r, w_r)
return e + r
if __name__ == '__main__':
model = TIF(dim_s=64, dim_l=64)
input1 = torch.randn(1, 64, 64, 64) # 例如来自小尺度特征的图像
input2 = torch.randn(1, 64, 64, 64) # 例如来自大尺度特征的图像
# 前向传播获取输出
output = model(input1, input2)
# 打印输入和输出的形状
print(input1.size())
print(input2.size()) print(output.size())
如何学习大模型 AI ?
由于新岗位的生产效率,要优于被取代岗位的生产效率,所以实际上整个社会的生产效率是提升的。
但是具体到个人,只能说是:
“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。
这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,也可以通过我们的能力和经验解答大家在人工智能学习中的很多困惑,所以在工作繁忙的情况下还是坚持各种整理和分享。但苦于知识传播途径有限,很多互联网行业朋友无法获得正确的资料得到学习提升,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
第一阶段(10天):初阶应用
该阶段让大家对大模型 AI有一个最前沿的认识,对大模型 AI 的理解超过 95% 的人,可以在相关讨论时发表高级、不跟风、又接地气的见解,别人只会和 AI 聊天,而你能调教 AI,并能用代码将大模型和业务衔接。
- 大模型 AI 能干什么?
- 大模型是怎样获得「智能」的?
- 用好 AI 的核心心法
- 大模型应用业务架构
- 大模型应用技术架构
- 代码示例:向 GPT-3.5 灌入新知识
- 提示工程的意义和核心思想
- Prompt 典型构成
- 指令调优方法论
- 思维链和思维树
- Prompt 攻击和防范
- …
第二阶段(30天):高阶应用
该阶段我们正式进入大模型 AI 进阶实战学习,学会构造私有知识库,扩展 AI 的能力。快速开发一个完整的基于 agent 对话机器人。掌握功能最强的大模型开发框架,抓住最新的技术进展,适合 Python 和 JavaScript 程序员。
- 为什么要做 RAG
- 搭建一个简单的 ChatPDF
- 检索的基础概念
- 什么是向量表示(Embeddings)
- 向量数据库与向量检索
- 基于向量检索的 RAG
- 搭建 RAG 系统的扩展知识
- 混合检索与 RAG-Fusion 简介
- 向量模型本地部署
- …
第三阶段(30天):模型训练
恭喜你,如果学到这里,你基本可以找到一份大模型 AI相关的工作,自己也能训练 GPT 了!通过微调,训练自己的垂直大模型,能独立训练开源多模态大模型,掌握更多技术方案。
到此为止,大概2个月的时间。你已经成为了一名“AI小子”。那么你还想往下探索吗?
- 为什么要做 RAG
- 什么是模型
- 什么是模型训练
- 求解器 & 损失函数简介
- 小实验2:手写一个简单的神经网络并训练它
- 什么是训练/预训练/微调/轻量化微调
- Transformer结构简介
- 轻量化微调
- 实验数据集的构建
- …
第四阶段(20天):商业闭环
对全球大模型从性能、吞吐量、成本等方面有一定的认知,可以在云端和本地等多种环境下部署大模型,找到适合自己的项目/创业方向,做一名被 AI 武装的产品经理。
- 硬件选型
- 带你了解全球大模型
- 使用国产大模型服务
- 搭建 OpenAI 代理
- 热身:基于阿里云 PAI 部署 Stable Diffusion
- 在本地计算机运行大模型
- 大模型的私有化部署
- 基于 vLLM 部署大模型
- 案例:如何优雅地在阿里云私有部署开源大模型
- 部署一套开源 LLM 项目
- 内容安全
- 互联网信息服务算法备案
- …
学习是一个过程,只要学习就会有挑战。天道酬勤,你越努力,就会成为越优秀的自己。
如果你能在15天内完成所有的任务,那你堪称天才。然而,如果你能完成 60-70% 的内容,你就已经开始具备成为一名大模型 AI 的正确特征了。