!转载请注明原文地址!——东方旅行者
更多行人重识别文章移步我的专栏:行人重识别专栏
难样本挖掘三元组损失(TriHard_Loss.py)
一、难样本挖掘三元组损失作用
用于计算度量损失,与表征学习阶段分类损失协同使用反向传播优化网络参数,且难样本三元组损失有利于网络学到更好的特征。
二、难样本挖掘三元组损失编写思路
在实现难样本挖掘三元损失时借助相似度矩阵进行计算,将批次图片按照顺序形成一个大小为P×K的方阵,方阵元素(0,2)代表第0张图片与第二张图片的相似度。如下图所示(图片来自浙江大学罗浩博士教学视频)
红色区域为每个行人与各自正样本之间的距离,而绿色区域为每个行人与各自负样本的距离。将该矩阵进行变换,将红色区域移到同一侧,绿色区域同一侧,对红色区域按行求最大值得到正样本最大距离向量(P×K,1),对绿色区域按行求最小值得到负样本最小距离向量(P*K,1),得到这两个向量即可根据公式计算难样本挖掘三元组损失。
编写代码时,计算损失需要继承nn.Module,重写__init__方法与forward方法。
__init__方法需要传入margin,并调用MarginRankingLoss计算三元组损失。
三、代码
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from IPython import embed
"""
本文件用于自定义难样本挖掘三元组损失,定义难样本挖掘三元组损失计算过程。
"""
class TripletLoss(nn.Module):
def __init__(self, margin=0.3):
super().__init__()
self.margin=margin
#计算三元组损失使用的函数
self.ranking_loss=nn.MarginRankingLoss(margin=margin)
def forward(self, inputs, targets):
n=inputs.size(0)
"""
计算图片之间的欧氏距离
矩阵A,B欧氏距离等于√(A^2 + (B^T)^2 - 2A(B^T))
"""
#计算A^2
distance=torch.pow(inputs,2).sum(dim=1, keepdim=True).expand(n,n)
#计算A^2 + (B^T)^2
distance=distance+distance.t()
#计算A^2 + (B^T)^2 - 2A(B^T)
distance.addmm(1,-2,inputs,inputs.t())
#计算√(A^2 + (B^T)^2 - 2A(B^T))
distance=distance.clamp(min=1e-12).sqrt()#该distance矩阵为对称矩阵
#获取对角线
mask=targets.expand(n,n)==targets.expand(n,n).t()#mask矩阵用于区分红绿色区域,即正样本区与负样本区,便于进行损失计算。
#list类型
distance_ap,distance_an=[],[]
for i in range(n):
distance_ap.append(distance[i][mask[i]].max().unsqueeze(0))#distance[i][mask[i]]使distance保留正样本区
distance_an.append(distance[i][mask[i]==0].min().unsqueeze(0))#distance[i][mask[i]==0]使distance保留负样本区
#经过for循环后,正样本最大距离与负样本最小距离都存储在list当中,需要将list元素连接成一个torch张量
distance_ap=torch.cat(distance_ap)
distance_an=torch.cat(distance_an)
#y指明ranking_loss前一个参数大于后一个参数
y=torch.ones_like(distance_an)
loss=self.ranking_loss(distance_an, distance_ap, y)
return loss
if __name__=='__main__':
target=[1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,6,6,6,6,7,7,7,7,8,8,8,8]
target=torch.Tensor(target)
features=torch.rand(32,2048)
a=TripletLoss()
loss=a.forward(features,target)
print(loss)
四、测试结果
tensor(0.8285)