Center Loss的Pytorch实现


Center Loss的Pytorch实现: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
这个损失函数也被使用在: deep-person-reid
github项目: https://github.com/KaiyangZhou/pytorch-center-loss

开始

Clone this repo and run the code.

$ git clone https://github.com/KaiyangZhou/pytorch-center-loss
$ cd pytorch-center-loss
$ python main.py --eval-freq 1 --gpu 0 --save-dir log/ --plot

You will see the following info in your terminal.

Currently using GPU: 0
Creating dataset: mnist
Creating model: cnn
==> Epoch 1/100
Batch 50/469     Loss 2.332793 (2.557837) XentLoss 2.332744 (2.388296) CenterLoss 0.000048 (0.169540)
Batch 100/469    Loss 2.354638 (2.463851) XentLoss 2.354637 (2.379078) CenterLoss 0.000001 (0.084773)
Batch 150/469    Loss 2.361732 (2.434477) XentLoss 2.361732 (2.377962) CenterLoss 0.000000 (0.056515)
Batch 200/469    Loss 2.336701 (2.417842) XentLoss 2.336700 (2.375455) CenterLoss 0.000001 (0.042386)
Batch 250/469    Loss 2.404814 (2.407015) XentLoss 2.404813 (2.373106) CenterLoss 0.000001 (0.033909)
Batch 300/469    Loss 2.338753 (2.398546) XentLoss 2.338752 (2.370288) CenterLoss 0.000001 (0.028258)
Batch 350/469    Loss 2.367068 (2.390672) XentLoss 2.367059 (2.366450) CenterLoss 0.000009 (0.024221)
Batch 400/469    Loss 2.344178 (2.384820) XentLoss 2.344142 (2.363620) CenterLoss 0.000036 (0.021199)
Batch 450/469    Loss 2.329708 (2.379460) XentLoss 2.329661 (2.360611) CenterLoss 0.000047 (0.018848)
==> Test
Accuracy (%): 10.32  Error rate (%): 89.68
... ...
==> Epoch 30/100
Batch 50/469     Loss 0.141117 (0.155986) XentLoss 0.084169 (0.091617) CenterLoss 0.056949 (0.064369)
Batch 100/469    Loss 0.138201 (0.151291) XentLoss 0.089146 (0.092839) CenterLoss 0.049055 (0.058452)
Batch 150/469    Loss 0.151055 (0.151985) XentLoss 0.090816 (0.092405) CenterLoss 0.060239 (0.059580)
Batch 200/469    Loss 0.150803 (0.153333) XentLoss 0.092857 (0.092156) CenterLoss 0.057946 (0.061176)
Batch 250/469    Loss 0.162954 (0.154971) XentLoss 0.094889 (0.092099) CenterLoss 0.068065 (0.062872)
Batch 300/469    Loss 0.162895 (0.156038) XentLoss 0.093100 (0.092034) CenterLoss 0.069795 (0.064004)
Batch 350/469    Loss 0.146187 (0.156491) XentLoss 0.082508 (0.091787) CenterLoss 0.063679 (0.064704)
Batch 400/469    Loss 0.171533 (0.157390) XentLoss 0.092526 (0.091674) CenterLoss 0.079007 (0.065716)
Batch 450/469    Loss 0.209196 (0.158371) XentLoss 0.098388 (0.091560) CenterLoss 0.110808 (0.066811)
==> Test
Accuracy (%): 98.51  Error rate (%): 1.49
... ...

Please run python main.py -h for more details regarding input arguments.

结果

We visualize the feature learning process below.
Softmax only. Left: training set. Right: test set.
在这里插入图片描述
Softmax + center loss. Left: training set. Right: test set.
在这里插入图片描述

在自己的项目中使用中心损失函数

  1. All you need is the center_loss.py file
from center_loss import CenterLoss
  1. Initialize center loss in the main function
center_loss = CenterLoss(num_classes=10, feat_dim=2, use_gpu=True)
  1. Construct an optimizer for center loss
optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)

Alternatively, you can merge optimizers of model and center loss, like

params = list(model.parameters()) + list(center_loss.parameters())
optimizer = torch.optim.SGD(params, lr=0.1) # here lr is the overall learning rate

  1. Update class centers just like how you update a pytorch model
# features (torch tensor): a 2D torch float tensor with shape (batch_size, feat_dim)
# labels (torch long tensor): 1D torch long tensor with shape (batch_size)
# alpha (float): weight for center loss
loss = center_loss(features, labels) * alpha + other_loss
optimizer_centloss.zero_grad()
loss.backward()
# multiple (1./alpha) in order to remove the effect of alpha on updating centers
for param in center_loss.parameters():
    param.grad.data *= (1./alpha)
optimizer_centloss.step()

If you adopt the second way (i.e. use one optimizer for both model and center loss), the update code would look like

loss = center_loss(features, labels) * alpha + other_loss
optimizer.zero_grad()
loss.backward()
for param in center_loss.parameters():
    # lr_cent is learning rate for center loss, e.g. lr_cent = 0.5
    param.grad.data *= (lr_cent / (alpha * lr))
optimizer.step()
  • 5
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
CenterLoss是一种用于增强深度学习模型分类性能的损失函数,它通过学习每个类别的中心点来减小样本与其对应类别中心之间的距离。在PyTorch中,有不同的实现方式。根据提供的引用,可以看到有两种不同的实现方式。 引用展示了一种简单直观的实现方式,定义了一个名为CenterLoss2的类,其中包含了一个num_class x num_feature维的参数centers,代表每个类别的中心点。在forward方法中,通过计算样本与其对应类别中心之间的距离dist,然后对dist进行限制和平均操作,得到最终的损失值loss。 引用展示了另一种实现方式,其中通过使用一个mask来选择与每个样本对应的中心点,然后计算样本与其对应类别中心之间的距离dist,并对dist进行平均操作,得到最终的损失值loss。 此外,引用展示了计算样本与每个中心点之间距离的方法,通过对样本和中心点进行平方和操作,得到一个距离矩阵distmat。 因此,根据不同的实现方式和计算方法,可以使用CenterLoss来优化深度学习模型的分类性能。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [center loss pytorch实现总结](https://blog.csdn.net/qq_45759229/article/details/126917939)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [【LossCenter loss代码详解(pytorch)](https://blog.csdn.net/m0_51358406/article/details/122312950)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值