需要解决的问题:
1. 逐行读取 txt 文件中的训练记录
2. 提取每行中的 epoch 和 loss 数据
3. 针对每个 epoch 的多个 batch 计算一个 mean_loss
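train_log 中的数据信息和格式(原文此处的示例缺失;从下面的代码可以看出:每条训练记录以日期开头(如 2021-08),并包含 "epoch: <数值>, batch:" 和 "train_loss: <数值>, time:" 两段字段):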
Python 代码
import re
import matplotlib.pyplot as plt
import os.path as osp
# Defaults used when this file is run as a script.
LOG_PATH = './train_log.txt'
RECORD_PREFIX = '2021-08'

# Compiled once at import time; each pattern captures the text between the
# field label and the separator that follows the field's value.
EPOCH_PATTERN = re.compile(r'epoch: (.*?), batch:')
LOSS_PATTERN = re.compile(r'train_loss: (.*?), time:')


def parse_record(line):
    """Extract ``(epoch, loss)`` from one training-log line.

    Args:
        line: A single line of the log file.

    Returns:
        A ``(epoch, loss)`` tuple of floats, or ``None`` when the line does
        not contain both an ``epoch: ..., batch:`` and a
        ``train_loss: ..., time:`` field, or when either value is not a
        valid number.
    """
    epoch_match = EPOCH_PATTERN.search(line)
    loss_match = LOSS_PATTERN.search(line)
    if epoch_match is None or loss_match is None:
        return None
    try:
        return float(epoch_match.group(1)), float(loss_match.group(1))
    except ValueError:
        # The fields are present but do not hold numbers; treat the line as
        # something other than a training record.
        return None


def mean_loss_per_epoch(lines, prefix=RECORD_PREFIX):
    """Average the training loss over all batches of each epoch.

    Args:
        lines: Iterable of log lines (an open text file works).
        prefix: Only lines starting with this string are considered.

    Returns:
        ``(epochs, mean_losses)``: two lists of equal length, ordered by
        ascending epoch, where ``mean_losses[i]`` is the mean loss of all
        records belonging to ``epochs[i]``. Both lists are empty when no
        record was found.
    """
    # epoch -> (sum of losses, number of batches). Keying by epoch means the
    # first epoch does not have to be 0 and a repeated epoch is merged with
    # its earlier records instead of polluting another epoch's average.
    totals = {}
    for line in lines:
        if not line.startswith(prefix):
            continue
        record = parse_record(line)
        if record is None:
            continue
        epoch, loss = record
        loss_sum, batch_count = totals.get(epoch, (0.0, 0))
        totals[epoch] = (loss_sum + loss, batch_count + 1)

    # Sort explicitly so the x values always line up with their means.
    epochs = sorted(totals)
    mean_losses = [totals[e][0] / totals[e][1] for e in epochs]
    return epochs, mean_losses


def main(log_path=LOG_PATH):
    """Plot the mean training loss per epoch and save it next to the log.

    The figure is written to the log's directory under the log's base name
    with a ``.png`` extension, then shown interactively.

    Args:
        log_path: Path of the training log to read.

    Raises:
        SystemExit: If the log contains no training records.
    """
    fullpath = osp.abspath(log_path)
    filedir, filename = osp.split(fullpath)

    with open(fullpath, 'r') as f:
        epochs, mean_losses = mean_loss_per_epoch(f)

    if not epochs:
        raise SystemExit('No training records found in ' + fullpath)

    plt.plot(epochs, mean_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Training Loss')
    # Explicit extension: without it matplotlib would try to infer the
    # format from any dot left in the base name.
    png_name = osp.splitext(filename)[0] + '.png'
    plt.savefig(osp.join(filedir, png_name))
    plt.show()


if __name__ == '__main__':
    main()