小样本学习下的Transformer:基于谱聚类层和标签代理学习

59ca205fa2836b6acfeda865588836d7.gif

©作者 | 知乎用户Alicia

研究方向 | 小样本学习

662242180f97bf2e03d9bcb4a576dba4.png

论文标题:

Attribute Surrogates Learning and Spectral Tokens Pooling in Transformers for few-shot learning

论文链接:

https://arxiv.org/pdf/2203.09064.pdf

代码链接:

https://github.com/StomachCold/HCTransformers


1d7ea58a0d8dd1fc1837c1073909dd54.png

Abstract & Conclusion

1.1 Abstract

目的:通过提高数据的有效性来解决 Transformer “数据饥渴”的问题,从而能够使用 Transformer 解决小样本问题。

创新点(如何提高数据的有效性):主要有两个工作,一个是提出了 token 谱聚类层,作用是获取并利用图像内部结构来降低目标前景区域和背景噪声区域分界的模糊性(也就是让目标和噪声的边界更清晰);第二个则是提出了一个标签代理学习方案,作用是充分利用图像-标签对的视觉信息而不是单一的标签视觉概念。

模型:HCTransformer,基于自监督学习框架 DINO,使用三组 Transformer 进行串联,两个 transformer 之间有一个 token 下采样层,同时采取标签代理学习方案来优化学习参数。

结果:在 4 个小样本的基准数据集上做了测试,在 5-way 1-shot 和 5-way 5-shot 上都以明显的优势超过了 DINO 基准模型,同时也超过了目前的 SOTA 模型。


1.2 conclusion

HCTransformer 旨在提高数据的有效性来解决小样本图像分类问题。尽管视觉 transformer 具有数据饥渴的特性,但我们在小样本学习问题上使用 transformer 取得了很好的结果。本文方法引入了一种隐性监督传播技术,通过可学习的标签代理隐性监督参数学习。我们提出了一个集成 patch token 的方案,它可以与 [cls] token 互补。此外,使用 token 谱聚类嵌入 transformer 的 token 之间的对象/场景布局和语义关系。HCTransfomer 不仅比 DINO 基准模型表现更好,而且在四个流行的基准数据集上以明显的优势胜过目前的 SOTA 模型。

1156cf8c616eb40d69bf82f850e3a4bf.png

Introduction & Related work

2.1 Introduction

ViT 在 CV 的很多任务上都表现很好,然而一些文献指出当训练数据不足时 ViT 模型性能就变得很低。本文研究了这样一个问题:当有标签的训练数据非常有限的情况下,使用 ViT 到底行得通吗?

既然有标注的图片有限,那么如何充分挖掘图像隐藏的信息,对其中的视觉概念有一个完整的描述就变得很重要。

本文希望提高数据的有效性以便在小样本图像分类上使用 ViT 模型。具体来说就是使用三组串联的 ViT 作为元特征提取器,每一组 ViT 在不同语义层面上对图像各区域的依赖性进行建模。每两个 ViT 之间有一个 token 谱聚类层进行下采样,将图像分割的语义信息传递给下一个 ViT。此外引入一个潜在的属性代理学习方案来学习视觉信息的稳定表征。

2.2 Related work

关于 tokens pooling:论文作者认为 ViT 模型关注的是整张图像的全局信息而忽略了局部区域的依赖关系,这才导致了 ViT “数据饥渴”的性质。后续有很多工作在解决这个问题。有一些方法通过将一个窗口内的 token 直接进行合并来减少 token 的数量,而 HCT 则允许 token 根据空间布局和语义相似性与相邻的 token 自适应合并。

关于监督 token:ViT 加入了一个 [cls] token 全局性地汇总所有 patch 的整体信息,只有这个 token 直接接收视觉信息。但是其他的 token 仍然拥有表达特殊模式的能力并可能协助最终的预测。一些文献提出移除 [cls] token,通过平均池化操作整合所有 token 的信息从而建立一个全局的 token。

LVViT 同时使用 [cls] token 和其他 token 的可能性,它利用 token 标记问题来重新定义分类任务。So-ViT 将二节协方差池化应用在视觉 token 上并与 [cls] 结合起来进行最终分类。本文方法与这两种方法直觉上相似,但是有着明显的区别。我们利用注意力机制挑选并最大限度地利用重要的 token。此外,我们假设融合的 token 与 [cls] token 并不共享特征空间,而是分别在各自的特征空间中对它们进行监督。

8dcfcea84aa99e5defa139c76c587910.png

Model & Method

766c714e9eafd6e708756099caabb1cd.png

▲ figure 1. model architecture


3.1 preliminary

由于 HCT 是基于 DINO 架构的,所以首先来看看 DINO 的架构以及其中的 Multi- crop 策略。图 2 是自监督学习框架 DINO 的简单表示。 和 是输入图像的不同视图。teacher 网络和 student 网络架构相同但是参数不同。其中 student 网络的参数通过反向传播进行更新,而 teacher 网络的参数并不通过反向传播更新,(sg 表示 stop gradient,截断梯度的传播了)而是由 student 网络的参数通过指数移动平均得到。

0a5d8d4c7aa2c2fa6d40ec1a1ed0802e.png

▲ figure 2. illustrate DINO in the case of one single pair of views for simplicity

multi- crop:对于输入的每一张图片,通过增强和裁剪获取两个 global views , 和 m 个 local views,将这些 views 的集合记作 V。注意,这里的 teacher 网络只接收 global views,而 student 网络接收 local views 以及 teacher 网络没有接收的 global views。则 DINO 的目标函数为:

c76b697b2f641f3dd0f01d7886dab12e.png

3.2 Attribute Surrogates Learning

假设共有 C 个类,对于每一类标签 y,学习一个语义标签代理 :,标签代理描述子 。(这里就是说我把标签变成一个可学习的向量即标签代理)训练过程中通过代理来监督 student 网络的参数学习,同时代理的参数也需要更新。假设 student 网络的的目标函数为 ,其参数 以及关联的标签代理的更新公式为:

9d44910132d5694979b1c2905dd474f0.png

为了充分利用 transformer 的优势,我们对于 patch token 和 class token 都要学习与一标签代理。

3.3 Supervise the Class Token

本文使用代理损失来监督每个类的概率分布,类 y 的代理描述子 ,使用 softmax 函数获得 的分布。(这里论文就是直接说 要取一个很大的数,他们将其设为 8192,没有具体说为什么设这个数)。

ed4bc030e4487c6cf4c1bf0d503aa9b6.png

这里的损失函数只考虑 global views 而不考虑 local views,作者说 local views 在更新类中心时由于有图像信息的丢失可能会产生负面影响。

3.4 Supervise Patch Tokens

Problem:patch token 没有标注所以难以监督。

本文做法:整合 patch tokens 的特征 作为图像的全局描述子,公式为

6b6624bf7be3ac3659ed5be596123df3.png

其中 表示 [cls] token 和 patch tokens 的相似性矩阵,其中每一个元素就是 [cls] token 和每一个 patch token 的相似性。 表示 patch token 的特征。

patch token的损失函数如下,其中 , 。

3336e89ccf6e11263e5b1f72a9d784fc.png

3.5 Spectral Tokens Pooling

本文提出一个基于谱聚类的 token 池化方法。对于 ViT 中的 N 个 patch,检索这 N 个 patch 的注意力矩阵 ,再计算一个 token 的临界矩阵 来反映邻接关系。每个中心 token 周围有 8 个邻接的 token。利用以下公式计算得到一个对称矩阵 S:

8b1e06a90870807adfa61fba1749a950.png

对对称矩阵 S 的每一行做 Softmax 操作,得到最终的邻接权重矩阵 ,作为 token 谱聚类算法的输入。Token 谱聚类对 patch token 进行下采样,得到新的 patch token。


634c6c9bb13e9a29edbcf61450fd2c04.png

▲ figure 3. visualized tokens pooling results

4a303ec59a8875c2ff958b413b676b38.png

Experiments


4.1 Training Strategy of HCTransformers


每个 transformer 输出的 token 数量分别为 784,392,196。

没有采用端到端的训练,而是将训练分成两个阶段。

第1阶段:按照 DINO 的设置训练训练第一组 transformer,目标函数是: 

c24e851023cabd67910b2f5ec45377be.png

第2阶段:冻结第一阶段两个 transformer 的参数,训练后两组 transformer,目标函数是:

450136dc787a0292650fec76999fb82f.png

实验结果这里就不多说了,总之就是效果很好,超过 SOTA。最后论文还给出了特征的可视化,但是论文正文似乎没有提及,就是放了个图。写完才发现好像上传不了图片,可对着论文正文看。

8c6eac5ed91c2217a31976d67a2c81cb.png

▲ figure 4. visualized features on train sets and val sets on mini ImageNet

更多阅读

f0af9398de7ce998a9ea85a162bd12d7.png

d6a602e36cf5bd7231c7818c2ecc3de1.png

66dc8551d1abafa7dbb5e14cba7c083a.png

6da6a7e9b0f1e61e8326e284f28f6e92.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

e8a402c2845bc1d2a4ea304c49d2feea.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

d174078c1a445c48f8b929caa782756e.png

### 基于 Transformer 的聚类算法实现方式 基于 Transformer 的聚类算法是一种结合了传统聚类方法(如 K-means)现代深度学习技术的方法。它利用 Transformer 中的自注意力机制提取特征,并通过优化目标函数完成数据点到簇中心的映射。 以下是基于 Transformer K-means 聚类的一个具体实现思路: #### 特征提取阶段 首先,使用预训练好的 Transformer 模型(例如 BERT 或 ViT),将输入数据转化为高维向量表示。这些向量捕捉到了数据的主要语义或空间特性。 ```python from transformers import BertTokenizer, BertModel import torch def extract_features(texts): tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased') inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True) with torch.no_grad(): outputs = model(**inputs) # 使用 [CLS] token 的隐藏状态作为句子嵌入 embeddings = outputs.last_hidden_state[:, 0, :].numpy() return embeddings ``` 此代码片段展示了如何使用 Hugging Face 提供的 `BertModel` 来获取文本数据的句向量[^1]。 --- #### 聚类阶段 接着,在获得的数据嵌入基础上应用 K-means 算法或其他聚类方法。为了增强效果,还可以引入交叉注意力机制动态调整簇中心的位置[^2]。 ```python from sklearn.cluster import KMeans def cluster_embeddings(embeddings, num_clusters): kmeans = KMeans(n_clusters=num_clusters, random_state=42) labels = kmeans.fit_predict(embeddings) centers = kmeans.cluster_centers_ return labels, centers ``` 上述代码实现了标准的 K-means 方法用于对提取出的特征进行划分[^3]。 --- #### 结合 Transformer 的改进方案 对于更复杂的场景,可以设计一种端到端的学习框架,其中包含以下组件: 1. **编码器模块**:负责生成高质量的样本表征; 2. **聚类**:采用软分配策略或者硬分配策略决定每个实例所属类别; 3. **损失函数**:联合考虑重建误差与分布一致性约束项。 下面是一个简化版的例子展示如何构建这样的架构并执行训练过程: ```python class ClusterTransformer(torch.nn.Module): def __init__(self, input_dim, hidden_dim, num_heads, num_clusters): super(ClusterTransformer, self).__init__() self.transformer_encoder = torch.nn.TransformerEncoder( encoder_layer=torch.nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads), num_layers=2 ) self.fc = torch.nn.Linear(input_dim, num_clusters) def forward(self, x): out = self.transformer_encoder(x.unsqueeze(1)).squeeze(1) logits = self.fc(out) probabilities = torch.softmax(logits, dim=-1) return probabilities def train(model, data_loader, optimizer, device): criterion = torch.nn.CrossEntropyLoss() model.train() total_loss = 0 for batch in data_loader: features, _ = batch features = features.to(device) predictions = model(features) pseudo_labels = torch.argmax(predictions.detach(), dim=-1).to(device) loss = criterion(predictions, pseudo_labels) total_loss += loss.item() optimizer.zero_grad() loss.backward() optimizer.step() avg_loss = total_loss / len(data_loader) return avg_loss ``` 在此示例中,我们定义了一个小型的 Transformer 编码器配合全连接输出概率分布,从而模拟聚类行为。 --- ### 总结 以上介绍了两种主要途径来达成基于 Transformer 的聚类任务——一是先单独运用 Transformer 获取良好表达再交给经典算法;二是创建融合两者的定制化神经网络结构以便更好地适应特定需求。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值