pytorch-metric-learning度量学习工具官方文档翻译

基于Pytorch实现的度量学习方法

度量学习相关的损失函数介绍:
基于度量学习方法实现音乐特征匹配的系列文章
1、整体总览

pytorch-metric-learning包含9个模块,每个模块都可以在现有的代码库中独立使用,或者组合在一起作为一个完整的训练、测试工作流。
在这里插入图片描述

1.1、自定义度量学习损失函数

损失函数可以使用距离、规约方法和正则化方法来进行定制化。在下面的图表中,miner在批量中找到难训练的样本对的索引,这些索引被索引到距离矩阵。
在这里插入图片描述

2、距离度量 Distances

Distance类用来计算成对的embedding之间的距离或者相似度。以三元组方法TripletMarginLoss为例,三元组的表示为<anchor, positive, negative>。其中anchor和positive构成正样本对,anchor和negative构成负样本对。

  • 如果使用的是距离度量方法,就是拉近anchor和positive的距离,推远anchor和negative的距离,公式表示为 $ [d_{ap} - d_{an} + margin]_{+} $。
### TripletMarginLoss with squared L2 distance ###
from pytorch_metric_learning.distances import LpDistance
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(power=2))

### TripletMarginLoss with unnormalized L1 distance ###
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(normalize_embeddings=False, p=1))

### TripletMarginLoss with signal-to-noise ratio###
from pytorch_metric_learning.distances import SNRDistance
loss_func = TripletMarginLoss(margin=0.2, distance=SNRDistance())
  • 如果使用的是相似度度量方法,就是增大anchor和positive的相似性,降低anchor和negative的相似性,公式表示为 $ [s_{an} - s_{ap} + margin]_{+} $。
### TripletMarginLoss with cosine similarity##
from pytorch_metric_learning.distances import CosineSimilarity
loss_func = TripletMarginLoss(margin=0.2, distance=CosineSimilarity())

备注: 所有的 losses、miners和regularizers都可以接受distance参数,但是有些方法有distance类型的限制,比如只能使用 CosineSimilarity 或者 DotProductSimilarity等,具体可以参考losses 页面

2.1 基础距离计算类 BaseDistance

所有的embedding距离/相似度计算子类都继承并扩展自BaseDistance类。

distances.BaseDistance(collect_stats = False,
                        normalize_embeddings=True, 
                        p=2, 
                        power=1, 
                        is_inverted=False)
  • 参数解释:
    • collect_stats: 如果为 True,将收集可能在实验中对分析有用的各种统计数据。如果为False,这些计算将被跳过。想让True为默认值吗? 设置全局 COLLECT_STATS 标志。
    • normalize_embeddings: 如果为 True,在计算loss之前,embedding将会被归一化为模为 1。
    • p: 距离范数
    • power: 如果不是 1,embedding的每一个元素都会被以 mat = mat ** self.power 方式放大
    • is_inverted: 应该由子类设置。如果为False,则较小的值表示靠近的embedding(距离度量相关的子类默认设置为False)。如果为True,则较大的值表示彼此相似的embedding(相似性度量相关的子类默认设置为True)

继承BaseDistance类需要实现以下两个方法:

# Must return a matrix where mat[j,k] represents the distance/similarity between query_emb[j] and ref_emb[k]
def compute_mat(self, query_emb, ref_emb):
    raise NotImplementedError

# Must return a tensor where output[j] represents the distance/similarity between query_emb[j] and ref_emb[j]
def pairwise_distance(self, query_emb, ref_emb):
    raise NotImplementedError
2.2 BatchedDistance

没用过,也没搞懂。

2.3 余弦相似度 CosineSimilarity

余弦相似度,当embedding做了模为 1 的归一化之后,等于点积相似度 DotProductSimilarity。

2.4 点积相似度 DotProductSimilarity

返回两个embedding向量的点积结果,当embedding做了模为 1 的归一化之后,等于余弦相似度 CosineSimilarity。

2.5 Lp范数 LpDistance

Lp范数,默认是L2范数,也就是欧几里得距离。

2.6 信噪比距离 SNRDistance

3. 损失函数 Losses

  • 所有的损失函数都可以使用下面这种方式使用:
from pytorch_metric_learning import losses
loss_func = losses.SomeLoss()
loss = loss_func(embeddings, labels) # in your training for-loop
  • 如果配合使用难样本挖掘方法,可以使用如下方式:
from pytorch_metric_learning import miners
miner_func = miners.SomeMiner()
loss_func = losses.SomeLoss()
miner_output = miner_func(embeddings, labels) # in your training for-loop
loss = loss_func(embeddings, labels, miner_output)
  • 如果传入的是二元组或者三元组,有些损失函数可以不需要label标记:
loss = loss_func(embeddings, indices_tuple=pairs)
# it also works with ref_emb
loss = loss_func(embeddings, indices_tuple=pairs, ref_emb=ref_emb)
  • 可以指定损失函数使用的规约方法reducer:
from pytorch_metric_learning import reducers
reducer = reducers.SomeReducer()
loss_func = losses.SomeLoss(reducer=reducer)
loss = loss_func(embeddings, labels) # in your training for-loop
  • 计算损失的两个embedding可以来自于不同的来源(source):
loss_func = losses.SomeLoss()
# anchors will come from embeddings, positives/negatives will come from ref_emb
loss = loss_func(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)
  • 对于分类型的损失,可以通过 get_logits 函数获取概率值:
loss_func = losses.SomeClassificationLoss()
logits = loss_func.get_logits(embeddings)

目前支持以下损失函数,不同的损失函数具有不同的原理,适用于不同的场景,拥有不太一样的配置参数。有些损失函数对于distance有特殊的要求,默认的reducer可能有区别,详情参考官网。pytorch-metric-learning代码当前支持以下损失函数:

  • AngularLoss
  • ArcFaceLoss
  • BaseMetricLossFunction
  • CircleLoss
  • ContrastiveLoss
  • CosFaceLoss
  • CrossBatchMemory
  • DynamicSoftMarginLoss
  • FastAPLoss
  • GenericPairLoss
  • GeneralizedLiftedStructureLoss
  • InstanceLoss
  • HistogramLoss
  • IntraPairVarianceLoss
  • LargeMarginSoftmaxLoss
  • LiftedStructureLoss
  • ManifoldLoss
  • MarginLoss
  • MultiSimilarityLoss
  • MultipleLosses
  • NCALoss
  • NormalizedSoftmaxLoss
  • NPairsLoss
  • NTXentLoss
  • P2SGradLoss
  • PNPLoss
  • ProxyAnchorLoss
  • ProxyNCALoss
  • RankedListLoss
  • SelfSupervisedLoss
  • SignalToNoiseRatioContrastiveLoss
  • SoftTripleLoss
  • SphereFaceLoss
  • SubCenterArcFaceLoss
  • SupConLoss
  • TripletMarginLoss
  • TupletMarginLoss
  • WeightRegularizerMixin
  • VICRegLoss

4. 难样本挖掘 Miners

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值