🏡作者主页:点击!
🤖编程探索专栏:点击!
⏰️创作时间:2024年12月2日15点34分
神秘男子影,
秘而不宣藏。
泣意深不见,
男子自持重,
子夜独自沉。
概述
SAM-Adapter 的核心思想是通过引入轻量级适配器,将任务特定知识注入到冻结的 SAM 模型中,以增强其在下游任务中的适应能力。适配器的设计简洁高效,通过灵活的任务知识输入,提升了模型的性能与泛化能力,特别是在数据稀缺场景下表现突出。
- 该文章分析了SAM作为基础模型的局限性,并提出如何利用SAM服务于下游任务的问题;
- 其提出的SAM-Adapter,创新性地整合任务特定知识与大模型地通用知识,灵活适应多种任务。
模型详述
1. 使用SAM作为骨干网络
- 目标:SAM-Adapter的目标是灵活利用SAM预训练模型中的知识;
- 骨干架构:使用SAM的图像编码器(基于ViT-H/16)作为骨干网络,同时保持其预训练权重不变;使用SAM的掩码解码器(包括修改后的Transformer解码器和动态掩码预测头),在预训练权重的基础上微调;
- 任务特定知识引入:通过适配器将任务特定知识 FiFi 注入网络,利用Prompts技术提升莫i下在下游任务中的泛化能力。
2. 输入任务特定知识
任务特定知识FiFi可以根据具体任务灵活设计,形式多样。其可以是从数据集中提取的特征(如纹理或频率信息),也可以是手工设计的规则,以及多种信息的组合形式:
其中FjFj为某种类型的知识或特征,wjwj为可调节的权重(用于控制组合强度)。
3. Adapters
结构:由两个多层感知器(MLP)和一个激活函数(GELU)组成:
Pi=MLPup(GELU(MLPtuneiFi))Pi=MLPup(GELU(MLPtuneiFi))
其中,MLPtuneiMLPtunei是线性层,用于为每个适配器生成任务特定的提示(prompts); MLPupMLPup 是一个共享的上投影层,用于调整Transformer特征的维度;GELUGELU是激活函数。 PiPi是输出的提示,附加到SAM模型的每一层Transformer中。
在该项目的代码实现中,是这样实现adpater的功能的:
class PromptGenerator(nn.Module):
def __init__(self, ...):
...
self.shared_mlp = nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim)
self.embedding_generator = nn.Linear(self.embed_dim, self.embed_dim//self.scale_factor)
for i in range(self.depth):
lightweight_mlp = nn.Sequential(
nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim//self.scale_factor),
nn.GELU()
)
setattr(self, 'lightweight_mlp_{}'.format(str(i)), lightweight_mlp)
self.prompt_generator = PatchEmbed2(img_size=img_size,
patch_size=patch_size, in_chans=3,
embed_dim=self.embed_dim//self.scale_factor)
...
def init_embeddings(self, x):
N, C, H, W = x.permute(0, 3, 1, 2).shape
x = x.reshape(N, C, H*W).permute(0, 2, 1)
return self.embedding_generator(x)
def init_handcrafted(self, x):
x = self.fft(x, self.freq_nums)
return self.prompt_generator(x)
def get_prompt(self, handcrafted_feature, embedding_feature):
N, C, H, W = handcrafted_feature.shape
handcrafted_feature = handcrafted_feature.view(N, C, H*W).permute(0, 2, 1)
prompts = []
for i in range(self.depth):
lightweight_mlp = getattr(self, 'lightweight_mlp_{}'.format(str(i)))
# prompt = proj_prompt(prompt)
prompt = lightweight_mlp(handcrafted_feature + embedding_feature)
prompts.append(self.shared_mlp(prompt))
return prompts
...
(1)这里只摘取了可以显示其大概思路的部分进行展示,至于细节则请参考该项目的具体实现;
(2)self.prompt_generator.init_embeddings和self.prompt_generator.init_handcrafted的实现均很简单,分别是线性层和卷积层;
(3)在获取handcrafted_feature时,运用了傅里叶变换,然后提取高频信息,对应的是原图像中的边缘、纹理等信息;
(4) embedding_feature更偏向图像的全局语义,适合提供通用背景信息, 而handcrafted_feature偏向图像的局部高频特征,适合突出任务关键细节。 两者互补,使得生成的prompts同时具有全局视角和局部任务适应性。
实验
数据集介绍
本文复现使用的是COD10K数据集,其在伪装目标检测(COD)领域具有重要地位,包含10,000张图像,涵盖78个类别(69个伪装类别,9个非伪装类别)。这些图像来自多种自然场景,包括5,066张伪装目标图像、3,000张背景图像和1,934张非伪装目标图像。数据集提供高分辨率图像和精细标注信息,可支持目标检测、分割和边缘检测等任务。其丰富的多样性和高质量标注使其成为伪装目标检测领域的重要研究资源。
复现流程
如图展示了该模型在COD10K的测试集中随机选取的两个样本上的预测效果。从左到右依次为原图、真实标签和模型预测结果。
- 下载附件中的项目代码、数据集和权重文件并放置在相应路径下。
数据集(cod10k)和相应的权重文件我已经准备好,网盘链接也放在了附件当中。 - 训练模型
python train_single.py --config [config_file_path]
train_single.py是我增加的适用于非分布式环境(单GPU)的训练脚本。
- 推理
python test.py --config [config_file_path] --model [model_path]
环境配置
python 3.9
torch2.2.1+cu121
A800
成功的路上没有捷径,只有不断的努力与坚持。如果你和我一样,坚信努力会带来回报,请关注我,点个赞,一起迎接更加美好的明天!你的支持是我继续前行的动力!"
"每一次创作都是一次学习的过程,文章中若有不足之处,还请大家多多包容。你的关注和点赞是对我最大的支持,也欢迎大家提出宝贵的意见和建议,让我不断进步。"
神秘泣男子