Python 从 train_log.txt 中画 loss 曲线,多epoch多batch

本文介绍如何使用Python逐行解析txt文件中的训练记录,提取loss和epoch信息,并为每个epoch计算mean_loss。示例代码展示了如何从'./train_log.txt'中提取数据并绘制训练损失随 epoch 变化的图表。
摘要由CSDN通过智能技术生成

需要解决的问题:

1. 逐行读取 txt 文件中的训练记录

2. 提取 每行中的 loss、epoch 数据信息

3. 针对每个 epoch 的多个 batch 计算一个 mean_loss

train_log 中的数据信息和格式:

Python 代码

import re
import matplotlib.pyplot as plt
import os.path as osp

fullpath = osp.abspath('./train_log.txt')
filedir, filename = osp.split(fullpath)
count, x = 0, 0
Loss, epoch = [], {0}

with open(fullpath, 'r') as f:
    while True:
        line = f.readline()
        if line == '':
            break
        if not line.startswith('2021-08'):
            continue

        _, start_epoch = re.search('epoch: ', line, flags=0).span()
        end_epoch, _ = re.search(', batch:', line, flags=0).span()
        current_epoch = float(line[start_epoch:end_epoch])

        _, start_loss = re.search('train_loss: ', line, flags=0).span()
        end_loss, _ = re.search(', time:', line, flags=0).span()
        current_loss = float(line[start_loss:end_loss])

        if current_epoch in epoch:
            x += current_loss
            count += 1
        else:
            epoch.add(current_epoch)
            Loss.append(x/count)
            x = current_loss
            count = 1
    Loss.append(x / count)

plt.plot(list(epoch), Loss)
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
pngName = filename.split('.')[0]
plt.savefig(osp.join(filedir, pngName))
plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值