【多标签分类问题的样本挖掘】Pytorch中的TripletMarginLoss的样本挖掘

多数度量学习的代码都需要进行挖掘,样本挖掘过程就是把一个Batch中的所有样本,根据标签来划分成正样本和负样本
这里我们只讨论多标签分类问题,标签是onehot编码,如果是单标签分类任务可以去看pytorch_metric_learning这个库有实现好的挖掘方法
比如输入样本为[Batch,Embedding],对应的标签是[Batch,Class]
对这些样本进行挖掘后得到以下三部分:

  1. Anchor :锚点样本,其实就是和输入的Batch一模一样,
  2. Positive Sample : 挖掘的正正样本
  3. Negtive Sample : 挖掘的负样本
import torch
import torch.nn as nn 
import torchvision

# 损失函数
class HibCriterion(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z_samples, alpha, beta, indices_tuple):
        
        n_samples = z_samples.shape[1]
                
        if len(indices_tuple) == 3:
            a, p, n = indices_tuple
            ap = an = a
        elif len(indices_tuple) == 4:
            ap, p, an, n = indices_tuple
        
        alpha = torch.nn.functional.softplus(alpha)

        loss = 0
        for i in range(n_samples):
            z_i = z_samples[:, i, :]
            for j in range(n_samples):
                z_j = z_samples[:, j, :]
        
                prob_pos = torch.sigmoid(- alpha * torch.sum((z_i[ap] - z_j[p])**2, dim=1) + beta) + 1e-6
                prob_neg = torch.sigmoid(- alpha * torch.sum((z_i[an] - z_j[n])**2, dim=1) + beta) + 1e-6
                
                # maximize the probability of positive pairs and minimize the probability of negative pairs
                loss += -torch.log(prob_pos) - torch.log(1 - prob_neg)
        loss = loss / (n_samples ** 2)
        
        return loss.mean()

def get_matches_and_diffs(labels):
    matches = (labels.float() @ labels.float().T).byte()
    diffs = matches ^ 1 # 异或运算得到负标签的矩阵
    return matches, diffs

def get_all_triplets_indices_vectorized_method(all_matches, all_diffs):
    """
    Args:
        all_matches (torch.Tensor): 相同标签
        all_diffs (torch.Tensor): 不相同标签

    Processing : all_matches.unsqueeze(2) -> [Batch,Batch,1]
                 all_diffs.unsqeeeze(1) -> [Batch,1,Batch] 

    Returns:
        torch.Tensor: _description_
    """
    
    triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
    return torch.where(triplets)

class TripletMinner(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.sim_mat = get_matches_and_diffs
        self.selctor = get_all_triplets_indices_vectorized_method

    def forward(self,labels):
        
        a , b = self.sim_mat(labels)
        c = self.selctor(a,b)

        return c
        
  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dou_Huanmin

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值