基于Pytorch实现的度量学习方法
- 开源代码:pytorch-metric-learning
- 官网文档:PyTorch Metric Learning官方文档
- 基于pytorch-metric-learning实现的度量学习模板代码:pytorch-metric-learning-template
度量学习相关的损失函数介绍:
- 度量学习DML之Contrastive Loss及其变种
- 度量学习DML之Triplet Loss
- 度量学习DML之Lifted Structure Loss
- 度量学习DML之Circle Loss
- 度量学习DML之Cross-Batch Memory
- 度量学习DML之MoCO
基于度量学习方法实现音乐特征匹配的系列文章
- 从零搭建音乐识别系统(一)整体功能介绍
- 从零搭建音乐识别系统(二)音频特征提取
- 从零搭建音乐识别系统(三)音乐分类模型
- 从零搭建音乐识别系统(四)embedding特征提取模型
- 从零搭建音乐识别系统(五)embedding特征提取模型验证
1、整体总览
pytorch-metric-learning包含9个模块,每个模块都可以在现有的代码库中独立使用,或者组合在一起作为一个完整的训练、测试工作流。
1.1、自定义度量学习损失函数
损失函数可以使用距离、规约方法和正则化方法来进行定制化。在下面的图表中,miner在批量中找到难训练的样本对的索引,这些索引被索引到距离矩阵。
2、距离度量 Distances
Distance类用来计算成对的embedding之间的距离或者相似度。以三元组方法TripletMarginLoss为例,三元组的表示为<anchor, positive, negative>。其中anchor和positive构成正样本对,anchor和negative构成负样本对。
- 如果使用的是距离度量方法,就是拉近anchor和positive的距离,推远anchor和negative的距离,公式表示为 $ [d_{ap} - d_{an} + margin]_{+} $。
### TripletMarginLoss with squared L2 distance ###
from pytorch_metric_learning.distances import LpDistance
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(power=2))
### TripletMarginLoss with unnormalized L1 distance ###
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(normalize_embeddings=False, p=1))
### TripletMarginLoss with signal-to-noise ratio###
from pytorch_metric_learning.distances import SNRDistance
loss_func = TripletMarginLoss(margin=0.2, distance=SNRDistance())
- 如果使用的是相似度度量方法,就是增大anchor和positive的相似性,降低anchor和negative的相似性,公式表示为 $ [s_{an} - s_{ap} + margin]_{+} $。
### TripletMarginLoss with cosine similarity##
from pytorch_metric_learning.distances import CosineSimilarity
loss_func = TripletMarginLoss(margin=0.2, distance=CosineSimilarity())
备注: 所有的 losses、miners和regularizers都可以接受distance参数,但是有些方法有distance类型的限制,比如只能使用 CosineSimilarity 或者 DotProductSimilarity等,具体可以参考losses 页面
2.1 基础距离计算类 BaseDistance
所有的embedding距离/相似度计算子类都继承并扩展自BaseDistance类。
distances.BaseDistance(collect_stats = False,
normalize_embeddings=True,
p=2,
power=1,
is_inverted=False)
- 参数解释:
- collect_stats: 如果为 True,将收集可能在实验中对分析有用的各种统计数据。如果为False,这些计算将被跳过。想让True为默认值吗? 设置全局 COLLECT_STATS 标志。
- normalize_embeddings: 如果为 True,在计算loss之前,embedding将会被归一化为模为 1。
- p: 距离范数
- power: 如果不是 1,embedding的每一个元素都会被以 mat = mat ** self.power 方式放大
- is_inverted: 应该由子类设置。如果为False,则较小的值表示靠近的embedding(距离度量相关的子类默认设置为False)。如果为True,则较大的值表示彼此相似的embedding(相似性度量相关的子类默认设置为True)
继承BaseDistance类需要实现以下两个方法:
# Must return a matrix where mat[j,k] represents the distance/similarity between query_emb[j] and ref_emb[k]
def compute_mat(self, query_emb, ref_emb):
raise NotImplementedError
# Must return a tensor where output[j] represents the distance/similarity between query_emb[j] and ref_emb[j]
def pairwise_distance(self, query_emb, ref_emb):
raise NotImplementedError
2.2 BatchedDistance
没用过,也没搞懂。
2.3 余弦相似度 CosineSimilarity
余弦相似度,当embedding做了模为 1 的归一化之后,等于点积相似度 DotProductSimilarity。
2.4 点积相似度 DotProductSimilarity
返回两个embedding向量的点积结果,当embedding做了模为 1 的归一化之后,等于余弦相似度 CosineSimilarity。
2.5 Lp范数 LpDistance
Lp范数,默认是L2范数,也就是欧几里得距离。
2.6 信噪比距离 SNRDistance
3. 损失函数 Losses
- 所有的损失函数都可以使用下面这种方式使用:
from pytorch_metric_learning import losses
loss_func = losses.SomeLoss()
loss = loss_func(embeddings, labels) # in your training for-loop
- 如果配合使用难样本挖掘方法,可以使用如下方式:
from pytorch_metric_learning import miners
miner_func = miners.SomeMiner()
loss_func = losses.SomeLoss()
miner_output = miner_func(embeddings, labels) # in your training for-loop
loss = loss_func(embeddings, labels, miner_output)
- 如果传入的是二元组或者三元组,有些损失函数可以不需要label标记:
loss = loss_func(embeddings, indices_tuple=pairs)
# it also works with ref_emb
loss = loss_func(embeddings, indices_tuple=pairs, ref_emb=ref_emb)
- 可以指定损失函数使用的规约方法reducer:
from pytorch_metric_learning import reducers
reducer = reducers.SomeReducer()
loss_func = losses.SomeLoss(reducer=reducer)
loss = loss_func(embeddings, labels) # in your training for-loop
- 计算损失的两个embedding可以来自于不同的来源(source):
loss_func = losses.SomeLoss()
# anchors will come from embeddings, positives/negatives will come from ref_emb
loss = loss_func(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)
- 对于分类型的损失,可以通过 get_logits 函数获取概率值:
loss_func = losses.SomeClassificationLoss()
logits = loss_func.get_logits(embeddings)
目前支持以下损失函数,不同的损失函数具有不同的原理,适用于不同的场景,拥有不太一样的配置参数。有些损失函数对于distance有特殊的要求,默认的reducer可能有区别,详情参考官网。pytorch-metric-learning代码当前支持以下损失函数:
- AngularLoss
- ArcFaceLoss
- BaseMetricLossFunction
- CircleLoss
- ContrastiveLoss
- CosFaceLoss
- CrossBatchMemory
- DynamicSoftMarginLoss
- FastAPLoss
- GenericPairLoss
- GeneralizedLiftedStructureLoss
- InstanceLoss
- HistogramLoss
- IntraPairVarianceLoss
- LargeMarginSoftmaxLoss
- LiftedStructureLoss
- ManifoldLoss
- MarginLoss
- MultiSimilarityLoss
- MultipleLosses
- NCALoss
- NormalizedSoftmaxLoss
- NPairsLoss
- NTXentLoss
- P2SGradLoss
- PNPLoss
- ProxyAnchorLoss
- ProxyNCALoss
- RankedListLoss
- SelfSupervisedLoss
- SignalToNoiseRatioContrastiveLoss
- SoftTripleLoss
- SphereFaceLoss
- SubCenterArcFaceLoss
- SupConLoss
- TripletMarginLoss
- TupletMarginLoss
- WeightRegularizerMixin
- VICRegLoss