利用SSD目标检测算法训练我们的数据集时,将训练结果保存在log文件中。
我的log文件如下
然后利用python脚本文件将我们需要的数据提取出来,也就是iteration,loss,(或者时epoch,loss).我的脚本文件如下:
import matplotlib.pyplot as plt
import numpy as np
import pylab as pl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
line = []
with open(r"FYans.log", encoding='utf-8') as f: # 从log文件中读出数据
for line1 in f:
# print(line1)
line.append(line1)
# print(line)
file_handle=open('1.txt',mode='w')
for item in line:
# 判断每一行是否以Epoch为开头
if(item.split('||')[0].startswith('Epoch')):
# print(item)
datalist = []
#按照||进行分割。得到列表strl
strl = item.split('||')
# print(strl)
# 这两行时删除掉无用的数据,得到了['Epoch:7 ', ' epochiter: 460/511', ' loss: 8.1518']这样的列表
del strl[2:4]
del strl[3:] # ['Epoch:7 ', ' epochiter: 460/511', ' loss: 8.1518']
for i in range(3):
if i==0:
# 可能的得到的数据带有空格,去除掉
str4 = strl[i].split(":")[-1].strip()
# print(strl[i].split(":")[-1])
elif i==1:
str1 = strl[i].split(":")[-1].strip()
# print(strl)
str4 = str1.replace("/511","")
# del strl[-4:]
# print(str4)
else:
str4 = strl[i].split(":")[-1].strip()
# print(str4)
datalist.append(str4)
# print(datalist) # ['7 ', '460', '8.1518']
# 下面的计算是为了得到我所需要的iteration,每个人可能不一样,视情况自己修改,
i = 0
x = (int(datalist[i])-1)*511+(int(datalist[i + 1]))
y = float(datalist[i+2])
# 写入到文件中
file_handle.write(str(x) + " " +str(y)+'\n')
file_handle.close()
如果没问题,会得到如下1.txt文件:
我的数据第一行是iteration,第二行是Loss,
下面就是绘制曲线plot.py
import os
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import pylab as pl
# 得到txt文件中的数据
def getLossAcc(logFile):
f = open(logFile, "r", encoding='utf-8')
line = f.readline() # 以行的形式进行读取文件
iterate = []
loss = []
while line:
nameArr = line.split(" ")
iterate.append(int(nameArr[0]))
loss.append(float(nameArr[1]))
# print(accurayStr)
line = f.readline()
f.close()
return iterate, loss
# 绘制曲线
def drawLine(iterate, loss, xName, yName, title,graduate):
# 横坐标 采用列表表达式
x = iterate
# 纵坐标
y = loss
# 生成折线图:函数polt
plt.plot(x, y)
# 设置横坐标说明
plt.xlabel(xName)
# 设置纵坐标说明
plt.ylabel(yName)
# 添加标题
plt.title(title)
# 设置纵坐标刻度
plt.yticks(graduate)
# 显示网格
plt.grid(True)
# 显示图表
plt.show()
# 保存结果图
plt.savefig("train_results_loss.png")
if __name__ == '__main__':
iterate, loss = getLossAcc(r"1.txt")
#主函数,其中最后一项列表是纵坐标的刻度,
drawLine(iterate, loss, "Iterated", "Loss", "Loss function curve", [0, 5, 10, 15, 20, 25, 30])
得到如下曲线:
希望帮助需要的你,
有问题留言,这些代码均在我的是成功之后贴上的,只需根据自己的需要进行简单修改就可用。