本文主要说明 PyTorch 损失函数封装中 size_average、reduce 和 reduction 三个参数的意义.
PyTorch 中提供损失函数的 类封装 torch.nn.modules.loss
和 函数封装 torch.nn.functional
. 他们有 size_average
、size_average
、reduction
三个参数,这三个参数理解起来有些复杂,做以下纪录.
-
size_average
(布尔类型, 可选参数)
已过时(Deprecated)(见 reduction). 一般地,losses 损失函数值为 batch 中对所有 loss 元素的平均值. 这里注意,对有些类型的损失函数,在单个训练样本中存在多个元素. 如果size_average
域设为False
,losses 损失函数值为 minibatch 中对所有 loss 元素的求和. 当reduce
设为False
时,忽略size_average
域. 缺省为:True
. -
reduce
(布尔类型, 可选参数)
已过时(Deprecated)(见 reduction). 一般地,losses 损失函数值为 minibatch 中对所有 loss 张量元素的平均值或求和,这取决于size_average
域的设置. 当reduce
为False
,返回 batch 中每个样本的 loss 值,并忽略size_average
. 缺省为:True
. -
reduction
(字符串类型, 可选参数) ’
确定对 loss 输出结果应用 reduction 的类型:‘none’
|‘mean’
|‘sum’
. 注意,size_average
和reduce
将在后续版本中被弃用(being deprecated),但与此同时,这两个参数的设置将覆盖reduction
. 缺省为:'mean'
.‘none’
:无 reduction 被应用.‘mean’
:对输出结果求和并除以输出结果张量中的元素个数.‘sum’
:对输出结果求和.
以上内容翻译自 PyTorch 官方文档,但是并不易于理解. 简单来说:
reduce
决定是求整个 batch 的 loss 值,还是求 batch 中每个 sample 的 loss 值. 默认为 True
求整个 batch 的 loss 值.
size_average
决定 loss 是求平均还是求和. 默认为 True
求平均. 并且当为 False
时,忽略 reduce
的设置.
reduction
的作用等同于 size_average
+ reduce
. ‘none’
为求 minibatch 中每个 sample 的 loss 值. ‘mean’
为求整个 minibatch 的 loss 值,对 minibatch 中所有 sample 的 loss 值求平均. ‘sum’
为求整个 minibatch 的 loss 值, 对 minibatch 中所有 sample 的 loss 值求和.