AverageMeter()的作用与用法

utils.py源码

from __future__ import division, absolute_import

__all__ = ['AverageMeter']


[docs]class AverageMeter(object):
    """Computes and stores the average and current value.

    Examples::
        >>> # Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # Update meter after every minibatch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

在pytorch中用utils包来更新得分、损失等等,百度根本搜不到,不行就google!!!
代码说的例子是输入有两个参数,一个是用来处理的数值,比如损失等等,另一个是批量大小。
比如损失,假设批次为32,那么每个batch_size更新一次。

train函数是模型训练的入口。首先一些变量的更新采用自定义的AverageMeter类来管理,后面会介绍该类的定义。然后model.train()是设置为训练模式。 for i, (input, target) in enumerate(train_loader) 是数据迭代读取的循环函数,具体而言,当执行enumerate(train_loader)的时候,是先调用DataLoader类的__iter__方法,该方法里面再调用DataLoaderIter类的初始化操作__init__。而当执行for循环操作时,调用DataLoaderIter类的__next__方法,在该方法中通过self.collate_fn接口读取self.dataset数据时就会调用TSNDataSet类的__getitem__方法,从而完成数据的迭代读取。读取到数据后就将数据从Tensor转换成Variable格式,然后执行模型的前向计算:output = model(input_var),得到的output就是batch size*class维度的Variable;损失函数计算: loss = criterion(output, target_var);准确率计算: prec1, prec5 = accuracy(output.data, target, topk=(1,5));模型参数更新等等。其中loss.backward()是损失回传, optimizer.step()是模型参数更新。

在train函数中采用自定义的AverageMeter类来管理一些变量的更新。在初始化的时候就调用的重置方法reset。当调用该类对象的update方法的时候就会进行变量更新,当要读取某个变量的时候,可以通过对象.属性的方式来读取,比如在train函数中的top1.val读取top1准确率。
 

在train函数中采用自定义的AverageMeter类来管理一些变量的更新。在初始化的时候就调用的重置方法reset。当调用该类对象的update方法的时候就会进行变量更新,当要读取某个变量的时候,可以通过对象.属性的方式来读取,比如在train函数中的top1.val读取top1准确率。
 

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值