triplet_loss.py 实现了深度学习中的三元组损失(Triplet Loss)和一些相关的辅助函数。以下是详细注释:
导入必要的库
from turtle import pd
import torch
from torch import nn
定义归一化函数
def normalize(x, axis=-1):
"""归一化到单位长度,沿指定维度进行。
Args:
x: PyTorch 变量
Returns:
x: PyTorch 变量,形状与输入相同
"""
x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
return x
定义欧氏距离函数
def euclidean_dist(x, y):
"""
计算两个张量之间的欧氏距离。
Args:
x: PyTorch 变量,形状 [m, d]
y: PyTorch 变量,形状 [n, d]
Returns:
dist: PyTorch 变量,形状 [m, n]
"""
m, n = x.size(0), y.size(0)
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) # m x 1 -> m x n
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() # n x 1 -> n x m, 然后转置
dist = xx + yy
dist = dist - 2 * torch.matmul(x, y.t())
dist = dist.clamp(min=1e-12).sqrt() # 数值稳定性
return dist
定义余弦距离函数
def cosine_dist(x, y):
"""
计算两个张量之间的余弦距离。
Args:
x: PyTorch 变量,形状 [m, d]
y: PyTorch 变量,形状 [n, d]
Returns:
dist: PyTorch 变量,形状 [m, n]
"""
m, n = x.size(0), y.size(0)
x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n)
y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t()
xy_intersection = torch.mm(x, y.t())
dist = xy_intersection / (x_norm * y_norm)
dist = (1. - dist) / 2
return dist
定义难例挖掘函数
def hard_example_mining(dist_mat, labels, return_inds=False):
"""对每个锚点找到最难的正样本和负样本。
Args:
dist_mat: PyTorch 变量,样本之间的距离矩阵,形状 [N, N]
labels: PyTorch 长张量,形状 [N]
return_inds: 是否返回索引。若为 `False`,节省时间(?)
Returns:
dist_ap: PyTorch 变量,距离(锚点,正样本);形状 [N]
dist_an: PyTorch 变量,距离(锚点,负样本);形状 [N]
p_inds: PyTorch 长张量,形状 [N];
选择的最难正样本的索引;0 <= p_inds[i] <= N - 1
n_inds: PyTorch 长张量,形状 [N];
选择的最难负样本的索引;0 <= n_inds[i] <= N - 1
注意:仅考虑所有标签具有相同数量样本的情况,
以便我们可以并行处理所有锚点。
"""
assert len(dist_mat.size()) == 2
assert dist_mat.size(0) == dist_mat.size(1)
N = dist_mat.size(0)
# 形状 [N, N]
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
# `dist_ap` 表示距离(锚点,正样本)
# `dist_ap` 和 `relative_p_inds` 的形状均为 [N, 1]
dist_ap, relative_p_inds = torch.max(
dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
# `dist_an` 表示距离(锚点,负样本)
# `dist_an` 和 `relative_n_inds` 的形状均为 [N, 1]
dist_an, relative_n_inds = torch.min(
dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
# 形状 [N]
dist_ap = dist_ap.squeeze(1)
dist_an = dist_an.squeeze(1)
if return_inds:
# 形状 [N, N]
ind = (labels.new().resize_as_(labels)
.copy_(torch.arange(0, N).long())
.unsqueeze(0).expand(N, N))
# 形状 [N, 1]
p_inds = torch.gather(
ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
n_inds = torch.gather(
ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
# 形状 [N]
p_inds = p_inds.squeeze(1)
n_inds = n_inds.squeeze(1)
return dist_ap, dist_an, p_inds, n_inds
return dist_ap, dist_an
定义三元组损失类
class TripletLoss(object):
"""
使用更难样本挖掘的三元组损失,
基于原始三元组损失修改。
"""
def __init__(self, margin=None, hard_factor=0.0):
self.margin = margin
self.hard_factor = hard_factor
if margin is not None:
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
else:
self.ranking_loss = nn.SoftMarginLoss()
def __call__(self, global_feat, labels, normalize_feature=False):
if normalize_feature:
global_feat = normalize(global_feat, axis=-1)
dist_mat = euclidean_dist(global_feat, global_feat) # 计算全局特征的距离矩阵
dist_ap, dist_an = hard_example_mining(dist_mat, labels) # 挖掘难样本
dist_ap *= (1.0 + self.hard_factor)
dist_an *= (1.0 - self.hard_factor)
y = dist_an.new().resize_as_(dist_an).fill_(1)
if self.margin is not None:
loss = self.ranking_loss(dist_an, dist_ap, y)
else:
loss = self.ranking_loss(dist_an - dist_ap, y)
return loss, dist_ap, dist_an
总结
- 归一化函数:将输入张量归一化到单位长度。
- 欧氏距离函数:计算两个张量之间的欧氏距离矩阵。
- 余弦距离函数:计算两个张量之间的余弦距离矩阵。
- 难例挖掘函数:找到每个样本中最难的正样本和负样本,返回它们的距离。
- 三元组损失类:使用更难样本挖掘的三元组损失函数,支持可选的边界(margin)和硬度因子(hard factor)。
这些函数和类共同实现了一个复杂的深度学习损失函数,用于改进特征学习的效果。