Matryoshka Representation Learning技术小结

文章介绍了MatryoshkaRepresentationLearning(MRL),一种通过一次训练生成不同维度特征表示的简单而有效的方法。MRL通过共享权重的FC和Classifier实现多尺度表征,适用于工程场景。实验结果展示了在ImageNet上的良好性能和插值性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

info
paperhttps://arxiv.org/abs/2205.13147
codehttps://github.com/RAIVNLab/MRL
org华盛顿大学、Google、哈弗大学
个人博客位置http://www.myhz0606.com/article/mrl

Motivation

我们平时做retrieval相关的工作,很多时候需要根据业务场景和计算资源对向量进行降维。受限开发周期,我们往往不会通过重新训练特征提取模型来调整向量维度,而是用PCA等方法来实现。但是当降维的scale较大时,PCA等方法的效果较差。Matryoshka Representation Learning (MRL)这篇paper介绍了一个很简单但有效的方法能实现一次训练,获取不同维度的表征提取。下面来看它具体是怎么做的吧。

Method

文中只描述MRL最核心的部分,详细介绍请看原论文。

我们以一个图像分类任务为例,其pipeline如下。图片首先通过一个Feature extractor提取特征,flatten后用一个FC来映射到表征空间,再接入一个classifier(也是个全连接层)得到该图片在类别上的概率分布。用这个方法训练,一次训练我们只能得到一种维度的图片表征(如图中是2048维)

在这里插入图片描述

为了一次训练获得不同维度的图片表征,最简单粗暴的方法就是我们可以用多个FC及对应的Classifier进行联合训练。这无疑是有效的,但由于FC和classifier多了,模型会大一些。

在这里插入图片描述

MRL对上面做了一个优化,它能通过一组FC和Classifier实现多种尺度的特征训练。pipeline如下图所示(图中同个颜色表示共享权重)。MRL实现的核心就是:对同一组FC和Classifier进行分片,从而实现不同维度的表征训练。

论文公式中的 F ( x i ; θ F ) F(x_i; \theta_{F}) F(xi;θF)是我图中的Feature_extractor + FC

min ⁡ { W ( m ) } m ∈ M ,   θ F 1 N ∑ i ∈ [ N ] ∑ m ∈ M c m ⋅ L ( W ( m ) ⋅ F ( x i ; θ F ) 1 : m   ;   y i )    , \min _ { \{ { \boldsymbol W } ^ { ( m ) } \} _ { m \in { \mathcal M } } , \, \theta _ { F } } \frac { 1 } { N } \sum _ { i \in [ N ] } \sum _ { m \in { \mathcal M } } c _ { m } \cdot { \mathcal L } ( { \boldsymbol W } ^ { ( m ) } \cdot F ( x _ { i } ; \theta _ { F } ) _ { 1 : m } \, ; \, y _ { i } ) \; , {W(m)}mM,θFminN1i[N]mMcmL(W(m)F(xi;θF)1:m;yi),

在这里插入图片描述

MRL的实现源码如下图所示:

class MRL_Linear_Layer(nn.Module):
	def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs):
		super(MRL_Linear_Layer, self).__init__()
		self.nesting_list = nesting_list
		self.num_classes = num_classes # Number of classes for classification
		self.efficient = efficient
		if self.efficient:
			setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))		
		else:	
			for i, num_feat in enumerate(self.nesting_list):
				setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))	

	def reset_parameters(self):
		if self.efficient:
			self.nesting_classifier_0.reset_parameters()
		else:
			for i in range(len(self.nesting_list)):
				getattr(self, f"nesting_classifier_{i}").reset_parameters()

	def forward(self, x):
		nesting_logits = ()
		for i, num_feat in enumerate(self.nesting_list):
			if self.efficient:
				if self.nesting_classifier_0.bias is None:
					nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
				else:
					nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
			else:
				nesting_logits +=  (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)

		return nesting_logits

Result

该图对比了MRL不同维度的表征在imagenet1K上linear classification和1-NN的准确率。

在这里插入图片描述

下图给出了scale model和dataset时MRL依旧有效,并且MRL提取的表征具备良好的插值性能。

在这里插入图片描述

更多实验结果见原论文。

小结

这篇文章虽然idea很简单,但很适合工程应用。

参考文献

Matryoshka Representation Learning

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值