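Writing a custom loss in PyTorch is much like writing a model: you subclass nn.Module. Some people instead define a plain function that computes the loss (mentioned here only in passing).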
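A minimal sketch of both styles follows, using a squared-error loss as a stand-in. The names `SquaredErrorLoss` and `squared_error_loss` are hypothetical and chosen only for illustration. Because `forward` uses only differentiable tensor operations, autograd handles the backward pass. The result can be checked against the built-in `F.mse_loss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SquaredErrorLoss(nn.Module):
    """Hypothetical example: squared error written as an nn.Module."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        loss = (pred - target) ** 2  # element-wise loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss  # 'none': keep per-element values


def squared_error_loss(pred, target):
    """Hypothetical function-style equivalent (mean reduction only)."""
    return ((pred - target) ** 2).mean()


if __name__ == '__main__':
    pred = torch.randn(4, 3, requires_grad=True)
    target = torch.randn(4, 3)
    loss = SquaredErrorLoss()(pred, target)
    loss.backward()  # gradients flow through the ordinary tensor ops in forward
    print(torch.allclose(loss, F.mse_loss(pred, target)))  # True
```

The module form is convenient when the loss carries configuration such as a margin or a reduction mode. The function form is enough for a one-off computation.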
Below I compare a few less commonly used losses: MarginRankingLoss, HingeEmbeddingLoss and CosineEmbeddingLoss.
HingeEmbeddingLoss Docs
The loss function for the $n$-th sample in the mini-batch is

$$
l_n = \begin{cases} x_n, & \text{if}\; y_n = 1,\\ \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1, \end{cases}
$$

where $\Delta$ is the margin (default 1). The total loss function is

$$
\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum',} \end{cases}
$$

where

$$
L = \{l_1,\dots,l_N\}^\top .
$$
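The following sketch connects the formula to code. It translates the formula element by element with `torch.where`, assuming `y` contains only 1 and -1. The function name `hinge_embedding_loss_ref` is hypothetical. On the small example values used in the built-in snippet, it should match `nn.HingeEmbeddingLoss`.

```python
import torch
import torch.nn as nn


def hinge_embedding_loss_ref(x, y, margin=1.0, reduction='mean'):
    """Element-wise translation of the formula; y must contain only 1 or -1."""
    # l_n = x_n where y_n = 1, and max(0, margin - x_n) where y_n = -1
    L = torch.where(y == 1, x, (margin - x).clamp_min(0))
    if reduction == 'mean':
        return L.mean()
    if reduction == 'sum':
        return L.sum()
    return L  # 'none': the vector L itself


x = torch.tensor([0.2, 0.4, 1.0])
y = torch.tensor([-1.0, 1.0, -1.0])
# L = [max(0, 1 - 0.2), 0.4, max(0, 1 - 1.0)] = [0.8, 0.4, 0.0]
print(hinge_embedding_loss_ref(x, y))           # tensor(0.4000)
print(nn.HingeEmbeddingLoss(margin=1.0)(x, y))  # tensor(0.4000)
```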
- Built-in PyTorch version

```python
import torch
import torch.nn as nn

hinge = nn.HingeEmbeddingLoss(margin=1., reduction='mean')
x = torch.tensor([[0.2, 0.4, 1.]])    # scores
y = torch.tensor([[-1., 1., -1.]])    # targets must be 1 or -1
loss = hinge(x, y)
print(loss)  # tensor(0.4000)
```
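- Custom implementation

```python
import random
import time

import torch
import torch.nn as nn


class HingeEmbeddingLoss(nn.Module):
    """HingeEmbeddingLoss that takes class indices instead of a 1/-1 target."""

    def __init__(self, margin=1., reduction='mean'):
        super().__init__()
        self.margin = margin
        self.reduction = reduction
        # Alternative: reuse the built-in loss after converting labels to 1/-1
        # self.hinge = nn.HingeEmbeddingLoss(margin=margin, reduction=reduction)

    def forward(self, x, label):
        # x: (N, C) scores; label: (N,) int64 class indices
        # Built-in route: -1 everywhere, +1 at the target class
        # hinge_label = (-torch.ones_like(x)).scatter(dim=1, index=label.view(-1, 1), src=torch.ones_like(x))
        # loss = self.hinge(x, hinge_label)
        # One-hot mask: 1 at the target class (y = 1), 0 elsewhere (y = -1)
        pos_mask = torch.zeros_like(x).scatter(dim=1, index=label.view(-1, 1), src=torch.ones_like(x))
        gt = x * pos_mask  # keep the raw score where y = 1
        # max(0, margin - x) where y = -1; target positions are zeroed, then gt is added back
        loss = (self.margin - x).clamp_min(0) * (1 - pos_mask) + gt
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss  # reduction='none': per-element losses


if __name__ == '__main__':
    # Use a large input so that the time cost is measurable
    x = torch.randn((128, 4096))
    y = torch.LongTensor(random.choices(range(299), k=128))
    loss = HingeEmbeddingLoss()
    start = time.perf_counter()
    print(loss(x, y))
    print(f'custom: {time.perf_counter() - start:.4f}s')
```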
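The logic of HingeEmbeddingLoss is simple:

- Where y = 1, the original score is kept.
- Where y = -1, the score is compared with the margin. The result is 0 if the score is at least the margin; otherwise it is the difference margin - x.

The official function expects the target in 1/-1 form. This is a one-hot encoding with the 0 entries replaced by -1. When the labels are class indices, the selection can instead be done directly with scatter.

However, the selected values of x kept changing during the computation. This implementation therefore takes a rather roundabout route. A one-hot mask picks out the scores where y = 1 (`gt`). The complementary mask zeroes those same positions in the hinge term. Finally, the two parts are added back together. I have not yet looked into why the values changed.

On the test case above, the CPU running time of this implementation differs from the official one by roughly 0.02 to 0.05 s.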
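The paragraph above describes two routes. The first converts class indices into a 1/-1 target (one-hot with 0 replaced by -1). The second does the selection directly with scatter. Here is a minimal sketch of both, assuming (N, C) scores and (N,) int64 class indices; `to_pm1_target` and `hinge_from_indices` are hypothetical helper names. The second helper swaps the mask-multiply-and-add for `torch.where`, which builds a new tensor and never writes into `x`. The timings it prints depend entirely on the machine.

```python
import time

import torch
import torch.nn as nn


def to_pm1_target(x, label):
    """One-hot encode class indices and replace the 0 entries with -1."""
    # -1 everywhere, then write +1 at each row's target column
    return torch.full_like(x, -1.0).scatter(1, label.view(-1, 1), value=1.0)


def hinge_from_indices(x, label, margin=1.0):
    """Hypothetical torch.where variant: out-of-place selection, mean reduction."""
    pos = torch.zeros_like(x).scatter(1, label.view(-1, 1), value=1.0).bool()
    # raw score at the target class, max(0, margin - x) everywhere else
    loss = torch.where(pos, x, (margin - x).clamp_min(0))
    return loss.mean()


if __name__ == '__main__':
    x = torch.randn(128, 4096)
    label = torch.randint(0, 299, (128,))
    builtin = nn.HingeEmbeddingLoss(margin=1.0)
    candidates = [
        ('built-in + 1/-1 target', lambda: builtin(x, to_pm1_target(x, label))),
        ('torch.where on indices', lambda: hinge_from_indices(x, label)),
    ]
    for name, fn in candidates:
        start = time.perf_counter()
        value = fn()
        print(f'{name}: {value.item():.6f} ({time.perf_counter() - start:.4f}s)')
```

Both routes should print the same loss value. The timing of the built-in route here includes the label conversion.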
MarginRankingLoss
CosineEmbeddingLoss
To be continued…