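训练了一个网络，每个 epoch 的损失等信息会打印在日志里，但只看数字很难直观地看出损失的变化趋势，所以写了个脚本把损失提取出来画成曲线。
日志的格式是这样的（下面的脚本据此解析：训练损失所在行包含 `train loss`，数值跟在 `loss : ` 之后；验证损失所在行包含 `validation loss`，数值跟在 `loss: ` 之后）：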
import numpy as np
from os import path
import os
import matplotlib.pyplot as plt
from matplotlib import cm,ticker
if __name__ == '__main__':
path="tt/training_log_2021_3_31_09_50_38.txt"
out = open(path, encoding='utf-8')
lines = out.readlines()
#提取trainLoss和validationLoss
trainLoss=[]
validationLoss=[]
for line in lines:
if "train loss" in line:
val=np.float(line.split("loss : ")[-1][:-1]) #[:-1]是去除末尾'\n'
trainLoss.append(val)
if "validation loss" in line:
val=np.float(line.split("loss: ")[-1][:-1])
validationLoss.append(val)
epochNum=len(trainLoss)
for i in range(epochNum):
print("epoch{}: train loss:{} val loss:{}".format(i,trainLoss[i],validationLoss[i]))
#绘图
fig=plt.figure()
xs=np.arange(epochNum)
plt.yticks(np.arange(-1,0,0.1))
plt.plot(xs, trainLoss, color='coral', label="train loss")
plt.plot(xs, validationLoss, color='g', label="val loss")
plt.legend()
plt.show()
#plt.savefig("loss.png")
得到的损失曲线如下图所示：
OVER