NLL_LOSS, CROSS_ENTROPY_LOSS详解
常用损失函数
nll
官方文档
torch. nn. functional. nll_loss( input , target, weight= None , size_average= None ,
ignore_index= - 100 , reduce = None , reduction= 'mean' )
i
n
p
u
t
,
(
N
,
C
)
;
t
a
r
g
e
t
,
(
N
)
;
w
e
i
g
h
t
,
(
C
)
input, (N, C); \ target,(N);\ weight,(C)
in p u t , ( N , C ) ; t a r g e t , ( N ) ; w e i g h t , ( C )
N
,
C
N,C
N , C 分别为批大小batch_size 和类别数class_num reduction及函数含义 :选取target 对应下标的结果,在batch 维,求和 (reduction=‘sum’)或求平均 (mean)或不操作 (none)作为结果。 size_average和reduce作用类似,应该不用管。返回值形状为
(
1
)
(1)
( 1 ) 或
(
N
)
(
n
o
n
e
情况下
)
(N) (none情况下)
( N ) ( n o n e 情况下 ) 。 ignore_index为忽略指定下标的值,不参与计算以及梯度传递。
cross_entropy
torch. nn. functional. cross_entropy( input , target, weight= None , size_average= None ,
ignore_index= - 100 , reduce = None , reduction= 'mean' , label_smoothing= 0.0 )
参数和返回值与NLL_LOSS基本一致。 含义 :nll取下标的结果先softmax后取对数,并加负号。
−
l
o
g
(
s
o
f
t
m
a
x
(
n
l
l
(
X
)
)
)
-log(softmax(nll(X)))
− l o g ( so f t ma x ( n ll ( X ))) 此处softmax即exp(下标结果)与∑exp(每个结果)比值。 log即ln自然对数。解释 :softmax首先得到一个类似概率的结果,取值(0,1),然后取对数为一个负值结果,在加上符号得到一个正的损失值。 当下标对应值结果大时,softmax结果更接近1,取log的负值越接近0,损失值也就越小(越接近0)。 同理,下标对应值小时,softmax结果接近0,取log的结果为负无穷,损失值为正无穷。
一段代码
def loss_fn ( out, tar) :
out = out. view( - 1 , out. shape[ - 1 ] )
tar = tar. view( - 1 )
return F. cross_entropy( out, tar, ignore_index= 2 )
transformer中的一段经典代码
此处比较结果out与tar目标值,来计算损失。 变量形状 :out为
(
N
,
L
,
C
)
(N, L, C)
( N , L , C ) ,tar为
(
N
,
L
)
(N, L)
( N , L ) 。其中L为句子长度,需要针对每个词计算损失。这里在使用前先将out和tar通过view转为了
(
N
∗
L
,
C
)
(N*L,C)
( N ∗ L , C ) 和
(
N
∗
L
)
(N*L)
( N ∗ L ) 形状,然后再来计算每个词的损失。 不转的话,按照函数定义,需要将out中的C转到第二维(中间)(N,C,L)。 ignore_index为忽略下标为2的tar结果,此处为pad填充,即transformer中mask掉的值。 最终结果为每个词的损失的平均值。