参照pytorch官网教程,了解了一下KL散度的调用方法,做个记录。
https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
计算公式:
而调用代码时torch.nn.KLDivLoss
和F.kl_div
是一样的,前者是一个类,在实现方法里调用了后者方法。
To avoid underflow issues when computing this quantity, this loss
expects the argument input in the log-space. The argument target may
also be provided in the log-space iflog_target
= True
需要传两个参数KLDivLoss(input,target)
。这里是用targlet
指导input
。上面说明了input
在传进去之前,需要做一个log
处理,很关键。传了log
之后,就不再是softmax
的概率和为1了,相当于log softmax
都成了负数。而target
传与不传log
后的都可以,有个target_log
参数,默认为False
。然后执行下述计算。
计算方法严格按照公式来,只是target
是否是log
,导致计算方式不同。
官方例子很清楚:
默认就是将input处理为log softmax
型,而target
保持原softmax
概率分布,用target
指导input
(虽然变为了log
,但只是为了计算,指导的还是变之前的概率分布)。或者target
也变为了log
,需要用log_target
声明。