大多数度量学习的代码都需要进行样本挖掘(mining):针对每个锚点样本,根据标签把同一个 Batch 中的样本划分为正样本和负样本。
这里只讨论多标签分类问题,标签采用 multi-hot 编码(一个样本可以同时属于多个类别);如果是单标签分类任务,可以直接使用 pytorch_metric_learning 库中已经实现好的挖掘方法。
例如,输入样本的形状为 [Batch, Embedding],对应标签的形状为 [Batch, Class]。
对这些样本进行挖掘后得到以下三部分:
- Anchor:锚点样本,其实就是输入的 Batch 本身;
- Positive Sample:挖掘出的正样本(与锚点至少共享一个标签);
- Negative Sample:挖掘出的负样本(与锚点没有任何共同标签)。
import torch
import torch.nn as nn
import torchvision
# Loss function
class HibCriterion(nn.Module):
    """Pairwise probabilistic loss over several sampled embeddings per input.

    For every pair of sample slots ``(i, j)`` the match probability of two
    embeddings is modelled as
    ``sigmoid(-softplus(alpha) * ||z_i - z_j||^2 + beta)``. The loss maximises
    this probability for positive pairs and minimises it for negative pairs,
    averaged over all ``n_samples ** 2`` slot pairs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z_samples, alpha, beta, indices_tuple):
        """Compute the loss for the mined pairs or triplets.

        Args:
            z_samples: Tensor indexed as ``z_samples[:, i, :]``; presumably
                shaped [Batch, n_samples, Embedding] (TODO confirm with caller).
            alpha: Raw scale tensor; ``softplus`` is applied so the effective
                scale is positive.
            beta: Offset added to the logits (tensor or float).
            indices_tuple: Either ``(a, p, n)`` triplet indices (the anchor
                index is used for both the positive and the negative pairs) or
                ``(ap, p, an, n)`` pair indices, whose positive and negative
                parts may have different lengths.

        Returns:
            A 0-dim tensor; it is 0 when no pairs are given.

        Raises:
            ValueError: If ``indices_tuple`` does not have 3 or 4 entries.
        """
        if len(indices_tuple) == 3:
            a, p, n = indices_tuple
            ap = an = a
        elif len(indices_tuple) == 4:
            ap, p, an, n = indices_tuple
        else:
            raise ValueError(
                f"indices_tuple must have 3 or 4 elements, got {len(indices_tuple)}"
            )
        n_samples = z_samples.shape[1]
        alpha = torch.nn.functional.softplus(alpha)
        # A zero that stays attached to the autograd graph, so backward()
        # still works when nothing was mined.
        loss = z_samples.sum() * 0.0
        for i in range(n_samples):
            z_i = z_samples[:, i, :]
            for j in range(n_samples):
                z_j = z_samples[:, j, :]
                # Positive and negative terms are averaged separately: in the
                # 4-tuple form they can differ in length, and an empty set is
                # skipped instead of producing a NaN mean. For equal lengths
                # mean(pos + neg) == mean(pos) + mean(neg).
                if len(p) > 0:
                    pos_logits = -alpha * torch.sum((z_i[ap] - z_j[p]) ** 2, dim=1) + beta
                    # -log(sigmoid(x)) == softplus(-x), stable for any x.
                    loss = loss + torch.nn.functional.softplus(-pos_logits).mean()
                if len(n) > 0:
                    neg_logits = -alpha * torch.sum((z_i[an] - z_j[n]) ** 2, dim=1) + beta
                    # -log(1 - sigmoid(x)) == softplus(x); unlike computing
                    # log(1 - sigmoid(x)) directly, this never takes log(<=0)
                    # when the sigmoid saturates to 1.
                    loss = loss + torch.nn.functional.softplus(neg_logits).mean()
        return loss / (n_samples ** 2)
def get_matches_and_diffs(labels):
    """Build positive/negative pair masks from multi-hot labels.

    Args:
        labels: Label tensor, assumed [Batch, Class] with 0/1 (multi-hot)
            entries — confirm with caller.

    Returns:
        tuple: ``(matches, diffs)``, two uint8 [Batch, Batch] masks.
        ``matches[i, j]`` is 1 when samples i and j share at least one label
        (this includes the diagonal for any sample that has a label);
        ``diffs`` is its 0/1 complement.
    """
    # labels @ labels.T counts how many labels two samples share. Threshold it
    # to 0/1 before casting: a raw uint8 count would keep values > 1 (making the
    # XOR below non-zero for genuine positives) and wrap a count of 256 to 0.
    shared = labels.float() @ labels.float().T
    matches = (shared > 0).byte()
    diffs = matches ^ 1  # XOR flips the 0/1 mask to get the negative-pair matrix
    return matches, diffs
def get_all_triplets_indices_vectorized_method(all_matches, all_diffs):
    """Enumerate every (anchor, positive, negative) index triplet.

    Args:
        all_matches (torch.Tensor): Square mask, non-zero where the
            (anchor, positive) pair shares a label.
        all_diffs (torch.Tensor): Square mask, non-zero where the
            (anchor, negative) pair shares no label.

    Processing:
        all_matches.unsqueeze(2) -> [Batch, Batch, 1]   (anchor, positive, -)
        all_diffs.unsqueeze(1)   -> [Batch, 1, Batch]   (anchor, -, negative)
        Their logical AND broadcasts to [Batch, Batch, Batch]; cell (a, p, n)
        is true when p is a positive and n is a negative for anchor a.
        Memory grows as Batch ** 3.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Anchor, positive and
        negative index tensors of equal length.
    """
    # Logical AND on bool masks rather than multiplying the raw masks: uint8
    # products wrap modulo 256 (e.g. 16 * 16 == 0) and would silently drop
    # valid triplets. For 0/1 masks the result is identical.
    triplets = all_matches.bool().unsqueeze(2) & all_diffs.bool().unsqueeze(1)
    return torch.where(triplets)
class TripletMinner(nn.Module):
    """Parameter-free module that mines index triplets from labels.

    ``forward`` first turns the labels into positive/negative pair masks via
    ``self.sim_mat`` and then expands those masks into (anchor, positive,
    negative) index tensors via ``self.selctor``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Both helpers are stored as plain attributes; the attribute names
        # (including the "selctor" spelling) are kept so external code that
        # reads them keeps working.
        self.sim_mat = get_matches_and_diffs
        self.selctor = get_all_triplets_indices_vectorized_method

    def forward(self, labels):
        """Mine triplets for one batch.

        Args:
            labels: Label tensor forwarded to ``self.sim_mat``; a multi-hot
                [Batch, Class] tensor is assumed — confirm with caller.

        Returns:
            Whatever ``self.selctor`` returns for the two masks; with the
            default helpers, a tuple of anchor, positive and negative indices.
        """
        pos_mask, neg_mask = self.sim_mat(labels)
        return self.selctor(pos_mask, neg_mask)