# 工具包Tools——LossHistory损失记录工具 (Tools package — LossHistory: loss-recording utility)
"""
# -*- utf-8 coding -*-#
作者:沐枫
日期:2021年08月01日
使用类LossHistory保存训练时候的损失
使用示例:
loss_history = LossHistory('./loss') or loss_history = LossHistory('./loss', 'loss_2021_08_02_19_24_09')
...
train_loss = ...
...
val_loss = ...
loss_history.add_loss(train_loss, val_loss) or loss_history.add_loss(train_loss)
"""
import os
import datetime
import matplotlib.pyplot as plt
import scipy.signal
class LossHistory:
    """Record per-epoch training/validation losses on disk and plot them.

    Every run writes into ``<root>/loss_<timestamp>/``: ``train.txt`` and
    ``val.txt`` (one ``str(loss)`` per line, the text ``None`` when no
    validation loss was given) and ``epoch_loss_<timestamp>.png``.
    """

    def __init__(self, root, is_continue=None):
        """
        :param root: root directory holding the loss results of every run of
            this project; created if missing.
        :param is_continue: name (or path) of an earlier run directory such as
            ``'loss_2021_08_02_19_24_09'``, used to resume an interrupted
            training; ``None`` starts a new run stamped with the current time.
        :raises AssertionError: when resuming and ``<root>/loss_<timestamp>``
            does not exist.
        """
        self.root = root
        self.__assert_dir(root)
        resume = is_continue is not None
        if resume:
            self.str_time = self.analysis(is_continue)
        else:
            self.str_time = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
        self.dir_name = 'loss_' + self.str_time
        self.loss_dir = os.path.join(self.root, self.dir_name)
        self.train_loss_path = os.path.join(self.loss_dir, 'train.txt')
        self.val_loss_path = os.path.join(self.loss_dir, 'val.txt')
        if resume:
            # Resuming must point at an existing run; never silently start an empty one.
            assert os.path.exists(self.loss_dir), 'direction {} is not exist.'.format(self.loss_dir)
            self.train_loss_list = self.read_loss(self.train_loss_path)
            self.val_loss_list = self.read_loss(self.val_loss_path)
        else:
            self.__assert_dir(self.loss_dir)
            self.train_loss_list = []
            self.val_loss_list = []

    @staticmethod
    def __assert_dir(dir_name):
        """Create ``dir_name`` (including parents) if it does not exist yet."""
        os.makedirs(dir_name, exist_ok=True)

    @staticmethod
    def analysis(continue_path):
        """Extract the timestamp from a run directory name.

        ``'loss_2021_08_02_19_24_09'`` -> ``'2021_08_02_19_24_09'``. Only the
        last path component is inspected, so ``'some_dir/loss_...'`` and a
        trailing separator are handled too. Returns ``''`` when the name has
        no underscore.
        """
        name = os.path.basename(os.path.normpath(continue_path))
        # Everything after the first '_' (the 'loss' prefix) is the timestamp.
        return name.partition('_')[2]

    def add_loss(self, train_loss, val_loss=None):
        """Append one epoch's losses to memory and to disk, then redraw the plot.

        :param train_loss: training loss of this epoch.
        :param val_loss: validation loss of this epoch; ``None`` is kept in
            memory and written as ``None`` when there is no validation step.
        """
        self.train_loss_list.append(train_loss)
        self.save_loss(self.train_loss_path, train_loss)
        self.val_loss_list.append(val_loss)
        self.save_loss(self.val_loss_path, val_loss)
        self.plt_loss()

    @staticmethod
    def save_loss(save_path, loss):
        """Append ``str(loss)`` as one line to ``save_path``."""
        with open(save_path, 'a', encoding='utf-8') as file:
            file.write(str(loss) + '\n')

    def read_loss(self, read_path):
        """Load a loss file written by :meth:`save_loss`.

        :param read_path: file to read.
        :return: ``[]`` if the file does not exist; for the validation file a
            list of ``None`` (one per line) if any line is ``None``; otherwise
            the lines parsed as floats. Blank lines are ignored.
        :raises ValueError: if a line cannot be parsed as a float.
        """
        if not os.path.exists(read_path):
            # A missing file means nothing was recorded yet; [] keeps add_loss working.
            return []
        with open(read_path, 'r', encoding='utf-8') as file:
            txt_list = [line.strip() for line in file if line.strip()]
        if read_path == self.val_loss_path and 'None' in txt_list:
            # Mirrors plt_loss: a run with any missing val loss has no val curve.
            return [None] * len(txt_list)
        return [float(value) for value in txt_list]

    def plt_loss(self):
        """Redraw ``epoch_loss_<timestamp>.png`` from the in-memory loss lists.

        The validation curve is drawn only when no recorded epoch lacks a
        validation loss. Savitzky-Golay smoothed curves (window 5 below 25
        epochs, otherwise 15; polyorder 3) are added once there are at least
        as many points as the window.
        """
        import warnings

        epochs = range(len(self.train_loss_list))
        draw_val = None not in self.val_loss_list
        plt.figure('loss')
        try:
            plt.plot(epochs, self.train_loss_list, 'red', linewidth=2, label='train loss')
            if draw_val:
                plt.plot(epochs, self.val_loss_list, 'coral', linewidth=2, label='val loss')
            window = 5 if len(self.train_loss_list) < 25 else 15
            # savgol_filter rejects inputs shorter than the window; skip smoothing until then.
            if len(self.train_loss_list) >= window:
                try:
                    plt.plot(epochs, scipy.signal.savgol_filter(self.train_loss_list, window, 3), 'green',
                             linestyle='--', linewidth=2, label='smooth train loss')
                    if draw_val:
                        plt.plot(epochs, scipy.signal.savgol_filter(self.val_loss_list, window, 3), '#8B4513',
                                 linestyle='--', linewidth=2, label='smooth val loss')
                except Exception as err:
                    # Smoothing is cosmetic: report it, but never abort training over it.
                    warnings.warn('loss smoothing skipped: {}'.format(err))
            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")
            plt.savefig(os.path.join(self.loss_dir, "epoch_loss_" + str(self.str_time) + ".png"))
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close('loss')
if __name__ == '__main__':
    # Demo run: log twelve synthetic (train, val) loss pairs under ./loss/loss_<timestamp>/.
    loss_history = LossHistory('./loss')
    for value in (15, 45, 7, 12, 49, 66, 18, 45, 18, 65, 35, 57):
        train_value, val_value = value / 2, value * 2 / 5
        loss_history.add_loss(train_value, val_value)
# 2021.12.20更新,可以保存任意数量的损失数据 (2021-12-20 update: LossHistoryN can record any number of loss series)
class LossHistoryN:
    """
    Record any number of named loss series (one text file per name) and plot them together.
    """

    def __init__(self, root, loss_name: 'list | tuple' = ('train',), path=None):
        """
        :param root: root directory for saved data; created if missing.
        :param loss_name: names of the losses to record, e.g. ``('train', 'val')``;
            a single string is treated as one name.
        :param path: suffix of this run's directory ``<root>/loss_<path>``; ``None``
            uses the current time formatted as ``%Y_%m_%d_%H_%M_%S``.
            NOTE: files in a reused directory are appended to, but earlier values
            are not re-loaded into memory or into the plot.
        """
        assert len(loss_name) != 0, "Loss name shouldn't be empty."
        if isinstance(loss_name, str):
            # Iterating a bare string would register one loss per character.
            loss_name = (loss_name,)
        self.root = self.__check_dir(root)
        self.loss_name = loss_name
        if path is None:
            self.str_time = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
        else:
            self.str_time = path
        self.dir_name = 'loss_%s' % self.str_time
        self.loss_dir = self.__check_dir(os.path.join(self.root, self.dir_name))
        self.dictionary = {name: [] for name in loss_name}
        self.paths = {name: os.path.join(self.loss_dir, '%s.txt' % name) for name in loss_name}
        self.length = 0
        self.colors = ['red', 'y', 'blue', 'k', 'green', 'm', 'black', 'c']
        self.linestyles = ['--', '-.', ':', '-']

    def _line_style(self, index):
        """Return (raw color, smoothed color, smoothed linestyle) for the index-th loss, cycling the palettes."""
        colors = self.colors
        return (colors[(2 * index) % len(colors)],
                colors[(2 * index + 1) % len(colors)],
                self.linestyles[index % len(self.linestyles)])

    def add_loss(self, **kwargs):
        """
        Append one value per loss name, write it to that name's file, and redraw the plot.

        :param kwargs: exactly one keyword per name in ``self.loss_name``,
            e.g. ``add_loss(train=0.5)``.
        :raises AssertionError: when the keywords do not match ``self.loss_name``;
            nothing is recorded in that case.
        """
        # Validate everything before mutating so a bad call cannot leave the series out of step.
        assert len(kwargs.keys()) == len(self.loss_name), 'kwargs name must equal loss name.'
        for name in kwargs:
            assert name in self.loss_name, 'kwargs.keys must in self.loss_name.'
        for name, loss_value in kwargs.items():
            self.dictionary[name].append(loss_value)
            self.save_loss(self.paths[name], loss_value)
        self.length += 1
        self.plt_loss()

    def plt_loss(self):
        """Redraw ``epoch_loss_<str_time>.png`` with every raw loss curve plus its smoothed curve.

        Smoothing uses a Savitzky-Golay filter (window 5 below 25 epochs,
        otherwise 15; polyorder 3) once there are at least as many points as
        the window.
        """
        import warnings

        epochs = range(self.length)
        window = 5 if self.length < 25 else 15
        plt.figure('loss')
        try:
            for index, name in enumerate(self.loss_name):
                color = self._line_style(index)[0]
                plt.plot(epochs, self.dictionary[name], color, linewidth=2, label='%s loss' % name)
            # savgol_filter rejects inputs shorter than the window; skip smoothing until then.
            if self.length >= window:
                try:
                    for index, name in enumerate(self.loss_name):
                        _, smooth_color, linestyle = self._line_style(index)
                        plt.plot(epochs, scipy.signal.savgol_filter(self.dictionary[name], window, 3), smooth_color,
                                 linestyle=linestyle, linewidth=2, label='smooth %s loss' % name)
                except Exception as err:
                    # Smoothing is cosmetic: report it, but never abort training over it.
                    warnings.warn('loss smoothing skipped: {}'.format(err))
            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")
            plt.savefig(os.path.join(self.loss_dir, "epoch_loss_%s.png" % self.str_time))
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close('loss')

    @staticmethod
    def save_loss(save_path, loss):
        """Append ``str(loss)`` as one line to ``save_path``."""
        with open(save_path, 'a', encoding='utf-8') as file:
            file.write(str(loss) + '\n')

    @staticmethod
    def __check_dir(_path):
        """Create ``_path`` (including parents) if missing and return it unchanged."""
        os.makedirs(_path, exist_ok=True)
        return _path
if __name__ == '__main__':
    # Demo run: log four synthetic loss series under ./loss/loss_<timestamp>/.
    # NOTE(review): the LossHistory demo earlier in this file also writes train.txt into
    # './loss/loss_<timestamp>' with a second-resolution timestamp; if both runs are created
    # within the same second they append to the same file.
    loss_history = LossHistoryN('./loss', loss_name=('train', 'hot_map', 'w_h', 'offset'))
    for value in (15, 45, 7, 12, 49, 66, 18, 45, 18, 65, 35, 57):
        loss_history.add_loss(
            train=value / 2,
            hot_map=value * 2 / 5,
            w_h=value * 4 / 3,
            offset=value * 100 / 555,
        )