Constructing Center Loss in PyTorch: A Detailed Walkthrough

When the input X has shape [batch_size, feat_dim], see the reference: Center loss-pytorch代码详解.

For input X of shape [batch_size, seq_len, feat_dim], the reference code is adapted as follows; the full code is:

import torch
import torch.nn as nn


class CenterLoss_seq(nn.Module):
    """Center loss for sequence inputs.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        batch_size (int): number of samples per mini-batch.
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, batch_size, num_classes, feat_dim, use_gpu=True):
        super(CenterLoss_seq, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        self.batch_size = batch_size

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))


    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, seq_len, feat_dim).
            labels: ground truth labels with shape (batch_size, seq_len).
        """
        batch_size = x.size(0)
        seq_len = x.size(1)

        # ||x||^2 summed over feat_dim, then expanded to (batch_size, seq_len, num_classes)
        x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes)

        # ||c||^2 per class, broadcast to (batch_size, seq_len, num_classes)
        center_pow = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        center_pow = center_pow.unsqueeze(1).expand(batch_size, seq_len, self.num_classes)

        # squared-distance expansion: start from ||x||^2 + ||c||^2
        distmat = x_pow + center_pow

        # Transpose and expand the centers here, inside forward, so the cross
        # term always uses the current values of the learnable centers.
        center_t = self.centers.t().expand(batch_size, self.feat_dim, self.num_classes)
        x_cen_mul = torch.matmul(x.float(), center_t.float())  # batched matrix multiply

        # complete ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * <x, c>
        distmat = distmat - 2 * x_cen_mul

        # one-hot mask: mask[b, s, j] is True iff labels[b, s] == j
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(2).expand(batch_size, seq_len, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, seq_len, self.num_classes))

        # keep only each position's own-class distance, then average over positions
        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)

        return loss

Parameters
num_classes: number of classes in the dataset
feat_dim: dimensionality of the feature vectors
batch_size: number of samples in a mini-batch
seq_len: sequence length
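
Before walking through forward, here is a minimal usage sketch continuing from the class above; the optimizer, learning rate, and random inputs are illustrative assumptions, not part of the original post:

num_classes, feat_dim, batch_size, seq_len = 5, 4, 1, 2

criterion = CenterLoss_seq(batch_size, num_classes, feat_dim, use_gpu=False)
optimizer = torch.optim.SGD(criterion.parameters(), lr=0.5)  # updates the learnable centers

x = torch.randn(batch_size, seq_len, feat_dim)  # stand-in for encoder features
labels = torch.tensor([[0, 1]])                 # (batch_size, seq_len)

loss = criterion(x, labels)
loss.backward()
optimizer.step()
print(loss.item())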

Let's now walk through the forward code.
As a concrete example, assume num_classes=5, feat_dim=4, batch_size=1, seq_len=2.
Input X: [batch_size, seq_len, feat_dim] = [1, 2, 4].
labels: [batch_size, seq_len] = [1, 2].
After initialization, centers has shape [num_classes, feat_dim]. Inside forward, center_t: [batch_size, feat_dim, num_classes] is obtained by transposing centers and expanding it along the batch dimension.
Running the following line produces x_pow.

		x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, self.num_classes)

Breaking that line down:
x: [batch_size, seq_len, feat_dim] = [1, 2, 4]

where each feature vector has dimension feat_dim = 4.
Squaring elementwise and summing over dim=2 with keepdim=True yields shape [1, 2, 1]; expand then broadcasts this to x_pow: [batch_size, seq_len, num_classes] = [1, 2, 5].
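
For intuition, a quick shape check of just this step in isolation, using the example dimensions (values are random and purely illustrative):

import torch

batch_size, seq_len, feat_dim, num_classes = 1, 2, 4, 5
x = torch.randn(batch_size, seq_len, feat_dim)

# (1, 2, 4) -> sum of squares over feat_dim -> (1, 2, 1) -> expand -> (1, 2, 5)
x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, num_classes)
print(x_pow.shape)  # torch.Size([1, 2, 5])

# every entry along the last dim repeats the same ||x[b, s]||^2
assert torch.allclose(x_pow, x_pow[..., :1].expand_as(x_pow))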

The same treatment is then applied to centers:

		center_pow = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
		center_pow = center_pow.unsqueeze(1).expand(batch_size, seq_len, self.num_classes)

		distmat = x_pow + center_pow

This gives center_pow: [batch_size, seq_len, num_classes] = [1, 2, 5].

Adding the two terms then gives distmat: [batch_size, seq_len, num_classes] = [1, 2, 5], where at this stage distmat[b, s, j] = ||x[b, s]||^2 + ||centers[j]||^2.
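
As an aside, the expand-then-transpose dance for center_pow can be written as a single reshape plus broadcast; a small sketch checking the equivalence under the example shapes (center_pow_alt is a hypothetical name introduced here):

import torch

batch_size, seq_len, feat_dim, num_classes = 1, 2, 4, 5
centers = torch.randn(num_classes, feat_dim)

center_pow = torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(num_classes, batch_size).t()
center_pow = center_pow.unsqueeze(1).expand(batch_size, seq_len, num_classes)

# simpler equivalent: (5,) -> (1, 1, 5) -> broadcast to (1, 2, 5)
center_pow_alt = torch.pow(centers, 2).sum(dim=1).view(1, 1, -1).expand(batch_size, seq_len, num_classes)

assert torch.equal(center_pow, center_pow_alt)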
Next, the cross term is computed with a batched matrix multiplication and subtracted:

		center_t = self.centers.t().expand(batch_size, self.feat_dim, self.num_classes)
		x_cen_mul = torch.matmul(x.float(), center_t.float())  # batched matrix multiply

		distmat = distmat - 2 * x_cen_mul

x_cen_mul: [batch_size, seq_len, num_classes] = [1, 2, 5].

Subtracting 2 * x_cen_mul completes the square, because ||x - c||^2 = ||x||^2 + ||c||^2 - 2<x, c>; the final distmat[b, s, j] is therefore the squared Euclidean distance between x[b, s] and center j.
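To double-check that this expansion really produces squared Euclidean distances, it can be compared against torch.cdist; a sketch with random data (the tolerance is an arbitrary assumption):

import torch

batch_size, seq_len, feat_dim, num_classes = 1, 2, 4, 5
x = torch.randn(batch_size, seq_len, feat_dim)
centers = torch.randn(num_classes, feat_dim)

x_pow = torch.pow(x, 2).sum(dim=2, keepdim=True).expand(batch_size, seq_len, num_classes)
center_pow = torch.pow(centers, 2).sum(dim=1).view(1, 1, -1).expand(batch_size, seq_len, num_classes)
center_t = centers.t().expand(batch_size, feat_dim, num_classes)
distmat = x_pow + center_pow - 2 * torch.matmul(x, center_t)

# distmat[b, s, j] should equal ||x[b, s] - centers[j]||^2
reference = torch.cdist(x, centers.unsqueeze(0).expand(batch_size, -1, -1)) ** 2
assert torch.allclose(distmat, reference, atol=1e-4)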
The mask is obtained by the following code.

		classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(2).expand(batch_size, seq_len, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, seq_len, self.num_classes))

labels has shape [batch_size, seq_len] = [1, 2]. Suppose the label values are [0, 1]. After the expansion, labels has shape [batch_size, seq_len, num_classes], with each position's label value repeated num_classes times along the last dimension.

classes = [0, 1, 2, 3, 4]; after expansion it also has shape [batch_size, seq_len, num_classes].

Comparing the two elementwise gives mask: [batch_size, seq_len, num_classes].
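
In other words, mask is just a one-hot encoding of the labels at every sequence position; a sketch verifying this against torch.nn.functional.one_hot, using the example values:

import torch
import torch.nn.functional as F

batch_size, seq_len, num_classes = 1, 2, 5
labels = torch.tensor([[0, 1]])

classes = torch.arange(num_classes).long()
expanded = labels.unsqueeze(2).expand(batch_size, seq_len, num_classes)
mask = expanded.eq(classes.expand(batch_size, seq_len, num_classes))

assert torch.equal(mask, F.one_hot(labels, num_classes).bool())
print(mask.int())  # a 1 exactly at each position's own class index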

 		dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / (batch_size * seq_len)

Multiplying distmat by mask elementwise keeps exactly the squared distance between each input position and its own class center; summing and dividing by batch_size * seq_len then gives the mean of these distances, with clamp restricting each value to the range [1e-12, 1e+12].
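
As a final sanity check, the vectorized loss can be compared with a naive per-position loop, continuing from the class above (the seed and tolerance are arbitrary assumptions):

import torch

torch.manual_seed(0)
batch_size, seq_len, feat_dim, num_classes = 1, 2, 4, 5
x = torch.randn(batch_size, seq_len, feat_dim)
labels = torch.tensor([[0, 1]])

criterion = CenterLoss_seq(batch_size, num_classes, feat_dim, use_gpu=False)
loss = criterion(x, labels)

# naive reference: mean over all positions of ||x[b, s] - centers[labels[b, s]]||^2
total = torch.tensor(0.0)
for b in range(batch_size):
    for s in range(seq_len):
        diff = x[b, s] - criterion.centers[labels[b, s]]
        total = total + (diff ** 2).sum()
naive = total / (batch_size * seq_len)

assert torch.allclose(loss, naive, atol=1e-5)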
