在训练时我看到有人会添加AverageMeter()的epoch_loss,
一般放在utils.py中,源码如下
rom __future__ import division, absolute_import
__all__ = ['AverageMeter']
class AverageMeter(object):
"""Computes and stores the average and current value.
Examples::
>>> # Initialize a meter to record loss
>>> losses = AverageMeter()
>>> # Update meter after every minibatch update
>>> losses.update(loss_value, batch_size)
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
在pytorch中用utils包来更新得分、损失等等。
代码说的例子是输入有两个参数,一个是用来处理的数值,比如损失等等,另一个是批量大小。
比如损失,假设批次为32,那么每个batch_size更新一次。
代码解释说明:
losses = AverageMeter()
loss_list = [0.5,0.4,0.5,0.6,1]
batch_size = 2
for los in loss_list:
losses.update(los,batch_size)
print(losses.avg)
本质上还是对所有batch_size的损失取平均。方便训练时输出每个batch的loss。
参考链接:https://blog.csdn.net/qq_39783265/article/details/105398427