torchnet.meter使用记录

训练过程中需要使用变量的平均值,如计算loss,一个批次内指标等。可以自己写一个类或者使用torchnet包。

1、自定义类实现平均

import torch
import numpy as np

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        format_str = (f"%s: avg=%s, val=%s, count=%s") % (
            self.name, str(self.avg), str(self.val), str(self.count))
        return format_str

调用方式:

1)单值调用

losses = AverageMeter("Loss")
for i in range(10):
    losses.update(i + 5, 1)
# print(losses.val, losses.avg, losses.count)
print(losses)

输出,其中avg为均值,val为当前值,count为个数,结果如下:

2)多值调用

losses = AverageMeter("Loss")
losses.reset()
for i in range(10):
    losses.update(np.array([1, 3 + i]), 1)
# print(str(losses.val), losses.avg, losses.count)
print(losses)

2、torchnet.meter包使用

包安装:pip install torchnet

2.1、AverageValueMeter

函数作用:添加单值数据,进行取平均值及标准差计算。

from torchnet import meter

loss_meter = meter.AverageValueMeter()
loss_meter.reset()
for i in range(10):
    loss_meter.add(i)
print(loss_meter.value()) # mean, std

输出:

使用重置(清空序列):loss_meter.reset()

2.2、ConfusionMeter

混淆矩阵。待整理...

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 像素格子 设计师:CSDN官方博客 返回首页