[center loss] demo

I noticed that center loss can be computed much like clustering, using different kernel functions (which can be viewed as distance functions) to compute the loss. Recording this demo here.
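
For reference, the center loss from Wen et al. (ECCV 2016) penalizes the squared distance between each feature and its class center:

    L_C = \frac{1}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2

where m is the mini-batch size, x_i is the feature of sample i, and c_{y_i} is the center of its class. The demo below averages over the batch size instead of using the fixed 1/2 factor, which only changes the effective loss weight.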


import os
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torch import nn


class CenterLoss(nn.Module):
    """Center loss.
    
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes, feat_dim):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
 
    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)

        # Squared Euclidean distance between every feature and every center,
        # expanded as ||x||^2 + ||c||^2 - 2 x.c into a (batch_size, num_classes) matrix.
        # The module (and hence self.centers) is expected to already be on x's device;
        # reassigning self.centers here would break the nn.Parameter registration.
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # One-hot mask selecting, for each sample, the column of its own class.
        classes = torch.arange(self.num_classes, dtype=torch.long, device=x.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # Keep only each sample's distance to its own center and average over the batch.
        dist = distmat * mask.float()
        loss = dist.clamp(min=0, max=1e12).sum() / batch_size
 
        return loss
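
As the note at the top suggests, the squared Euclidean distance above is just one choice of "kernel". A minimal sketch of a variant using cosine distance instead (CosineCenterLoss is my own illustrative name and design, not part of the original demo):

import torch.nn.functional as F

class CosineCenterLoss(nn.Module):
    """Hypothetical center-loss variant: cosine distance to the class center."""
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, x, labels):
        # Cosine similarity between every feature and every center: (batch, num_classes).
        sim = F.normalize(x, dim=1) @ F.normalize(self.centers, dim=1).t()
        # 1 - similarity to each sample's own center, averaged over the batch.
        own = sim.gather(1, labels.unsqueeze(1)).squeeze(1)
        return (1.0 - own).mean()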
    


os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Run from the repository root so the relative data path below resolves.
os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
cLoss = CenterLoss(10, 128)
mnist_train = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)


# A simple MLP. Its 128-d output serves both as the feature fed to the center loss
# and, somewhat unusually, directly as the logits for cross-entropy (with labels
# 0-9, only 10 of the 128 output dims end up acting as class scores).
net = nn.Sequential(
    nn.Linear(28*28, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 128)
)


net = net.cuda()
cLoss = cLoss.cuda()

optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
optimizer_center = torch.optim.SGD(cLoss.parameters(), lr=0.5)

train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

for epoch in range(50):
    net.train()
    cLoss.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28).cuda()
        target = target.cuda()
        feature = net(data)
        loss1 = nn.CrossEntropyLoss()(feature, target)   # softmax loss on the 128-d output
        loss2 = cLoss(feature, target)                   # center loss on the same features
        loss = loss1 + 0.1 * loss2                       # 0.1 is the center-loss weight (lambda)
        optimizer.zero_grad()
        optimizer_center.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer_center.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    net.eval()
    cLoss.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(-1, 28*28).cuda()
            target = target.cuda()
            feature = net(data)
            # Sum per-sample losses here; dividing by the dataset size below then
            # gives a true average (summing batch means would be off by batch_size).
            test_loss += nn.CrossEntropyLoss(reduction='sum')(feature, target).item()
            pred = feature.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
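
A side note on the two optimizers: because the center loss enters the total loss with weight 0.1, the gradients flowing into cLoss.centers are scaled down by that same factor, which is why optimizer_center uses a much larger learning rate (0.5). A trick used in some public reimplementations is to instead rescale the center gradients explicitly after backward(). A minimal sketch of the inner loop with that change (weight_cent is my own name for the 0.1 factor):

# Inside the training loop, replacing the backward/step lines above:
weight_cent = 0.1
loss = loss1 + weight_cent * loss2
optimizer.zero_grad()
optimizer_center.zero_grad()
loss.backward()
# Cancel the weight_cent scaling so the centers are updated at full strength.
for param in cLoss.parameters():
    param.grad.data *= (1.0 / weight_cent)
optimizer.step()
optimizer_center.step()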
    


Center Loss is a loss function used in face recognition and feature learning. It cannot be applied directly in its naive form, however, because it would have to take the entire training set into account and average the features of every class at each iteration, which is inefficient. Some modification is therefore needed.

One fix is mini-batch Center Loss, which computes each class's feature center on every mini-batch: for each class present in the batch, gather that class's features and compute their center, then update the centers by minimizing the distance between each feature and its corresponding class center. The centers are thus refreshed per mini-batch without ever touching the whole training set (see the sketch after the references below).

Another fix is an online-updated Center Loss, which computes the centers during each sample's forward pass and updates them during the backward pass, so the centers are adjusted per sample rather than by averaging every class's features at each iteration.

In short, mini-batch or online updating makes Center Loss efficient enough to be practical.

References:
[1][2] CenterLoss原理详解(通透): https://blog.csdn.net/weixin_54546190/article/details/124504683
[3] CenterLoss | 减小类间距离: https://blog.csdn.net/qiu931110/article/details/106108936
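
A minimal sketch of the mini-batch center update described above, using an exponential moving average (the alpha parameter and update rule are my own illustrative choices, not from a specific paper):

import torch

@torch.no_grad()
def update_centers(centers, features, labels, alpha=0.5):
    """Move each class center toward the mean of that class's features
    in the current mini-batch (hypothetical EMA-style update)."""
    for c in labels.unique():
        batch_mean = features[labels == c].mean(dim=0)
        centers[c] = (1 - alpha) * centers[c] + alpha * batch_mean
    return centers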