深度学习训练测试中日志记录方法合集

一、tensorboard

from tensorboardX import SummaryWriter

...

# Each SummaryWriter writes its own event file.
# '.' saves under the current working directory, i.e. what os.getcwd() returns.
train_log_path = '.'
valid_log_path = '.'
train_writer = SummaryWriter(train_log_path, filename_suffix='TRAIN')
val_writer = SummaryWriter(valid_log_path, filename_suffix='VAL')

...

# Number of batches per epoch; it does not change inside the loop, so compute it once.
n_batchsize = len(train_dataloader)

for epoch in range(0, max_epochs):

    ...

    # Running sums for the per-epoch averages; reset at the start of every epoch.
    total_loss1 = 0.0
    total_loss2 = 0.0

    for batch_idx, batch in enumerate(train_dataloader):
        ...
        loss1 = ...
        loss2 = ...
        # Global step, so the per-step curves continue across epochs.
        step = epoch * n_batchsize + batch_idx
        train_writer.add_scalar('Loss/Step/loss1', loss1, step)
        train_writer.add_scalar('Loss/Step/loss2', loss2, step)
        # Presumably loss1/loss2 are scalar torch tensors: float() converts them
        # to Python numbers so the running sum does not keep every batch's
        # autograd graph alive.
        total_loss1 += float(loss1)
        total_loss2 += float(loss2)

    # Average over the epoch; max(..., 1) guards against an empty dataloader.
    avg_loss1 = total_loss1 / max(n_batchsize, 1)
    avg_loss2 = total_loss2 / max(n_batchsize, 1)
    train_writer.add_scalar('Loss/Epoch/loss1', avg_loss1, epoch)
    train_writer.add_scalar('Loss/Epoch/loss2', avg_loss2, epoch)

# Flush pending events and release both event files.
train_writer.close()
val_writer.close()

# Visualize from a terminal (this is a shell command, not Python):
#   tensorboard --logdir="/your/events_path"

二、logging


import logging
import torch.distributed as dist
import os

def get_logger(log_name, log_level, log_file=None, file_mode='a'):
    """Return a ``logging.Logger`` configured for (optionally distributed) training.

    Every process gets a console handler. Only rank 0 also writes to
    ``log_file`` and logs at ``log_level``. Other ranks are raised to
    ``logging.ERROR`` so multi-GPU runs do not print every message N times.

    Calling this again with the same ``log_name`` reconfigures the logger.
    Handlers attached earlier are closed and replaced, so messages are not
    emitted more than once.

    Args:
        log_name: Name identifying the logger (passed to ``logging.getLogger``).
        log_level: Logging level, e.g. ``logging.INFO``, ``logging.DEBUG``,
            ``logging.ERROR``, ``logging.CRITICAL``.
        log_file: Path of the log file, or ``None`` for console-only logging.
            Missing parent directories are created.
        file_mode: Mode used to open ``log_file``: ``'a'`` appends,
            ``'w'`` overwrites.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(log_name)
    # Stop messages from also reaching the parent (root) logger's handlers.
    logger.propagate = False

    # logging.getLogger returns the same object for the same name, so drop the
    # handlers from any earlier call; otherwise every message is duplicated.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Rank of this process when torch.distributed is initialized, else 0.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # Console output (StreamHandler defaults to sys.stderr).
    handlers = [logging.StreamHandler()]

    if rank == 0 and log_file is not None:
        log_dir = os.path.dirname(log_file)
        # dirname is '' for a bare filename such as 'train.log', and
        # os.makedirs('') raises, so only create a directory when one is given.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Explicit UTF-8 so non-ASCII messages are written regardless of OS locale.
        handlers.append(logging.FileHandler(log_file, file_mode, encoding='utf-8'))

    formatter = logging.Formatter(
        "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-zero ranks only report errors, to avoid one copy of each message per process.
    logger.setLevel(log_level if rank == 0 else logging.ERROR)
    return logger


if __name__ == "__main__":
    # Minimal demo: log to the console and overwrite ./log/train.log.
    log_name = "Debug"
    # Available levels: logging.DEBUG, INFO, WARNING, ERROR, CRITICAL.
    log_level = logging.INFO
    file_mode = 'w'             # 'w' overwrites the log file; 'a' would append
    log_file = './log/train.log'

    logger = get_logger(log_name, log_level, log_file, file_mode=file_mode)
    logger.info("=> Loading config ...")
    logger.info("=> Start train!")
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值