Pytorch模型训练中 使用的 MetricLogger类总结

MetricLogger类


这个类主要用来打印输出训练的时候产生的一些数据
首先搬出我们看到的源代码,主要是在看何凯明大佬的MAE项目代码的时候遇到了,一起来学习一下~~
MAE-github官方源代码:https://github.com/facebookresearch/mae

源代码

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))

首先我们来看一下SmoothValue这个类干了啥:
首先,它的属性里面有
属性

self.deque利用队列来获取的数值
self.total记录累计的数值的总和
self.count记录所由累计的个数的总和
self.fmtfmt = “{median:.4f} ({global_avg:.4f})”

方法
update:更新数值

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

synchronize_between_processes:同步进程数值
dist.barrier():阻塞进程,等待所有进程完成计算
dist.all_reduce():把所有节点上计算好的数值进行累加,然后传递给所有的节点。

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

然后就是一些median,avg,max,global_avg,value方法,比较简单。

MetricLogger类就两个属性,一个是self.meters,另一个是self.delimiter
self.meters = defaultdict(SmoothedValue):这样meters的值即value可以使用SmoothedValue类的属性和方法(python 一切皆对象的思想)
self.delimiter = delimeter:是一个字符串类型
self.meters里面是是一个字典:{name:meter}meter是字典

主要信息存贮在log_msg的列表里面了

log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
if torch.cuda.is_available():
        log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)


##打印信息
if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
  • 15
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
PyTorch是目前深度学习领域最受欢迎的开源框架之一。该框架提供了丰富的功能,包括构建计算图的灵活性、对GPU的支持,以及易于调试和可视化的接口。 PyTorch模型训练步骤与其他深度学习框架似,但也有其独特之处。以下是一些PyTorch模型训练实用教程: 1. 准备数据: PyTorch提供了一些实用的来创建和加载数据集。您可以使用DataLoader来创建批量数据并进行数据预处理。还可以使用transform将数据转换为需要的格式。 2. 构建模型使用PyTorch构建模型非常容易。您只需定义模型的结构和构造函数即可。PyTorch支持多种模型型,包括卷积神经网络、循环神经网络和转移学习。 3. 定义损失函数: 损失函数是模型最关键的部分之一。PyTorch提供多种用于分、回归和聚的损失函数。您还可以创建自定义损失函数。 4. 优化算法: 优化算法是用于更新模型参数的方法。PyTorch支持多种优化算法,包括随机梯度下降、Adam和Adagrad。此外,可以通过定义自己的优化算法来实现个性化的优化。 5. 训练模型训练模型使用深度学习时最耗时的部分之一。在PyTorch,您可以使用for循环迭代训练数据,并使用backward()函数进行反向传播。还可以使用scheduler动态地调整学习率。 6. 评估模型: 评估模型是确保模型工作正常的必要步骤之一。您可以使用PyTorch提供的来计算模型的准确性、F1分数等指标。 总体来说,PyTorch对于初学者和专业人士来说都是一种极具吸引力的深度学习框架。通过了解PyTorch的基本功能,您可以更好地了解如何使用它来训练自己的模型

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值