Center Loss 的 PyTorch 实现总结

实现1

参考
https://github.com/KaiyangZhou/pytorch-center-loss


# Original implementation (from KaiyangZhou/pytorch-center-loss)
import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    The loss is the squared Euclidean distance between each feature vector and
    the learnable center of its class, summed over the batch and divided by the
    batch size.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): if True, the class centers are created on CUDA.
    """
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        # One learnable center per class, shape (num_classes, feat_dim).
        centers = torch.randn(self.num_classes, self.feat_dim)
        if self.use_gpu:
            centers = centers.cuda()
        self.centers = nn.Parameter(centers)

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: sum over the batch of the squared distance between
            each sample and its class center, divided by batch_size.
        """
        batch_size = x.size(0)
        # Squared distances to every center via ||x||^2 + ||c||^2 - 2 x.c,
        # giving a (batch_size, num_classes) matrix.
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # Keyword beta/alpha: the positional addmm_(beta, alpha, mat1, mat2)
        # overload is deprecated.
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # Select each sample's distance to its own class center. Gathering
        # (instead of multiplying by a 0/1 mask and clamping afterwards) keeps
        # the non-target classes from each contributing a spurious 1e-12.
        index = labels.long().view(-1, 1).to(distmat.device)
        dist = distmat.gather(1, index)

        # The expanded formula can dip slightly below zero from rounding.
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss
# scJoint
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


class CenterLoss(nn.Module):
    """Center loss variant (scJoint style) over a list of embedding batches.

    For each (embeddings[i], labels[i]) pair, the Euclidean (not squared)
    distance of every sample to its class center is summed and divided by
    batch_size * num_classes. This matches the original torch.mean over the
    masked (batch_size, num_classes) distance matrix. The per-batch values are
    then summed across the list.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): if True, the class centers are created on CUDA.
    """
    def __init__(self, num_classes=20, feat_dim=64, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: iterable of feature matrices, each (batch_size, feat_dim).
            labels: iterable of label vectors, aligned element-wise with embeddings.

        Returns:
            Sum of the per-batch losses (the integer 0 if embeddings is empty).
        """
        center_loss = 0
        for x, label in zip(embeddings, labels):
            batch_size = x.size(0)
            # Squared distances ||x||^2 + ||c||^2 - 2 x.c, shape (batch_size, num_classes).
            distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                      torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
            # Keyword beta/alpha: the positional overload is deprecated.
            distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

            # Only the target-class column is kept. Taking sqrt over the whole
            # matrix produced NaN for slightly negative (rounding) entries, and
            # an infinite derivative at exact zeros became 0*inf = NaN through
            # the 0/1 mask. Clamping before sqrt keeps value and gradient finite.
            index = label.long().view(-1, 1).to(distmat.device)
            sq_dist = distmat.gather(1, index).clamp(min=1e-12)
            dist = torch.sqrt(sq_dist).clamp(min=1e-12, max=1e+12)

            # Same scaling as the original mean over batch_size * num_classes entries.
            center_loss = center_loss + dist.sum() / (batch_size * self.num_classes)

        # Per-batch losses are summed, not averaged (the list usually has one element).
        return center_loss
    

# Demo: random features and labels scored with the list-based CenterLoss.
torch.manual_seed(42)
# torch.manual_seed does not seed NumPy's global RNG; seed it too so the
# random data (and hence the printed loss) is reproducible across runs.
np.random.seed(42)
num_class = 10
num_feature = 64
num_sample = 256


center_loss = CenterLoss(num_classes=num_class, feat_dim=num_feature, use_gpu=False)
embedding = np.random.randn(num_sample, num_feature)
label = np.random.randint(0, num_class, size=num_sample)

# This CenterLoss expects lists of batches; here the list holds a single batch.
embeddings = [torch.tensor(embedding, dtype=torch.float32)]
labels = [torch.tensor(label, dtype=torch.long)]

print(center_loss(embeddings, labels))

#print(center_loss.centers)

结果如下
（原文此处为运行结果截图，已缺失。）

实现2（更容易理解）

class CenterLoss2(nn.Module):
    """Center loss written with a direct per-sample center lookup.

    Returns the mean, over the last dimension of the distance tensor, of the
    Euclidean distance between each feature vector and its class center.

    Args:
        num_class (int): number of classes.
        num_feature (int): feature dimension.
    """
    def __init__(self, num_class=10, num_feature=2):
        super(CenterLoss2, self).__init__()
        self.num_class = num_class
        self.num_feature = num_feature
        # One learnable center per class, shape (num_class, num_feature).
        self.centers = nn.Parameter(torch.randn(self.num_class, self.num_feature))

    def forward(self, x, labels):
        """
        Args:
            x: features; presumably (batch_size, num_feature) - TODO confirm with callers.
            labels: integer class indices used to index the centers, one per row of x.

        Returns:
            Mean Euclidean distance between each sample and its class center.
        """
        # Center of each sample's class.
        center = self.centers[labels]
        sq_dist = (x - center).pow(2).sum(dim=-1)
        # Clamp before sqrt. sqrt has an infinite derivative at 0, which turned
        # into NaN gradients whenever a sample coincided with its center.
        dist = torch.sqrt(sq_dist.clamp(min=1e-12))
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)

        return loss


# Re-seed with the same value so CenterLoss2's randn draws the same initial
# centers as the earlier CenterLoss instance.
torch.manual_seed(42)
center_loss2 = CenterLoss2(num_class=num_class, num_feature=num_feature)

# CenterLoss2 takes one batch directly instead of a list of batches.
embeddings = torch.as_tensor(embedding, dtype=torch.float32)
labels = torch.as_tensor(label, dtype=torch.long)

# Implementation 1 averages over a (batch_size, num_classes) matrix, so divide
# by num_class to bring this result onto the same scale before comparing.
scaled_loss = center_loss2(embeddings, labels) / num_class
print(scaled_loss)
#print(center_loss2.centers)

结果如下
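（原文此处为运行结果截图，已缺失。）可以看到两种实现的结果在打印精度内一模一样。注意实现2 的结果要除以 num_class，因为实现1 是对整个 (batch_size, num_classes) 距离矩阵取平均，而每行只有本类那一项有实际贡献。因此可以不必再使用第一种实现。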

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值