深度学习训练工具包Tools——LossHistory损失记录工具

工具包Tools——LossHistory损失记录工具

"""
# -*- utf-8 coding -*-#
作者:沐枫
日期:2021年08月01日

使用类LossHistory保存训练时候的损失
使用示例:
loss_history = LossHistory('./loss) or loss_history = LossHistory('./loss, 'loss_2021_08_02_19_24_09)
...
train_loss = ...
...
val_loss = ...
loss_history.add_loss(train_loss, val_loss) or loss_history.add_loss(train_loss)
"""

import os
import datetime
import matplotlib.pyplot as plt
import scipy.signal


class LossHistory:
    def __init__(self, root, is_continue=None):
        """

        :param root: 根目录,主要是保存这个项目运行的所有loss结果的文件夹
        :param is_continue: 继续训练的根目录,是与否接着之前的训练,主要是为了继续上次被中断的训练继续训练
        """
        self.root = root
        self.__assert_dir(root)

        if is_continue is not None:
            self.str_time = self.analysis(is_continue)  # 解析目录名字,获得上次训练的时间字符串

            self.dir_name = 'loss_' + self.str_time  # 保存训练文件的文件夹名字
            self.loss_dir = os.path.join(self.root, self.dir_name)  # 路径
            assert os.path.exists(self.loss_dir), 'direction {} is not exist.'.format(self.loss_dir)

            self.train_loss_path = os.path.join(self.loss_dir, 'train.txt')  # 读出以前训练的数据
            self.val_loss_path = os.path.join(self.loss_dir, 'val.txt')

            self.train_loss_list = self.read_loss(self.train_loss_path)
            self.val_loss_list = self.read_loss(self.val_loss_path)

            # print(self.train_loss_list)
            # print(self.val_loss_list)

        else:
            current_time = datetime.datetime.now()
            str_time = datetime.datetime.strftime(current_time, '%Y_%m_%d_%H_%M_%S')
            self.str_time = str_time

            self.dir_name = 'loss_' + self.str_time
            self.loss_dir = os.path.join(self.root, self.dir_name)
            self.__assert_dir(self.loss_dir)

            self.train_loss_path = os.path.join(self.loss_dir, 'train.txt')
            self.val_loss_path = os.path.join(self.loss_dir, 'val.txt')

            self.train_loss_list = []
            self.val_loss_list = []

    @staticmethod
    def __assert_dir(dir_name):  # 判断文件夹是否存在
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)

    @staticmethod
    def analysis(continue_path):  # 解析字符串获得字符串包含的时间信息
        list0 = continue_path.split('_')
        list0.pop(0)
        str_time = ''
        for ii, s in enumerate(list0):
            if ii != len(list0) - 1:
                str_time = str_time + s + '_'
            else:
                str_time = str_time + s
        return str_time

    def add_loss(self, train_loss, val_loss=None):  # 添加loss数据
        self.train_loss_list.append(train_loss)
        self.save_loss(self.train_loss_path, train_loss)

        if val_loss is None:
            self.val_loss_list.append(None)
            self.save_loss(self.val_loss_path, None)
        else:
            self.val_loss_list.append(val_loss)
            self.save_loss(self.val_loss_path, val_loss)

        self.plt_loss()

    @staticmethod
    def save_loss(save_path, loss):  # 保存loss数据
        with open(save_path, 'a', encoding='utf-8') as file:
            file.write(str(loss) + '\n')

    def read_loss(self, read_path):  # 读出loss数据
        if os.path.exists(read_path):
            with open(read_path, 'r', encoding='utf-8') as file:
                txt_list = [loss.strip() for loss in file.readlines()]
                if read_path is self.val_loss_path:
                    if 'None' in txt_list:  # 如果'None'在文件中那么就直接val数据全部置为None
                        loss_list = []
                        for _ in range(len(txt_list)):
                            loss_list.append(None)
                        return loss_list
                    else:
                        loss_list = [float(loss.strip()) for loss in txt_list]
                        return loss_list

                else:
                    loss_list = [float(loss.strip()) for loss in txt_list]
                    return loss_list

    def plt_loss(self):  # 绘制数据
        length = range(len(self.train_loss_list))

        plt.figure('loss')
        plt.plot(length, self.train_loss_list, 'red', linewidth=2, label='train loss')  # 在画布上绘制训练损失曲线

        if None not in self.val_loss_list:
            plt.plot(length, self.val_loss_list, 'coral', linewidth=2, label='val loss')  # 在画布上绘制验证损失曲线

        try:
            if len(self.train_loss_list) < 25:
                num = 5
            else:
                num = 15

            plt.plot(length, scipy.signal.savgol_filter(self.train_loss_list, num, 3), 'green', linestyle='--',
                     linewidth=2, label='smooth train loss')

            if None not in self.val_loss_list:
                plt.plot(length, scipy.signal.savgol_filter(self.val_loss_list, num, 3), '#8B4513', linestyle='--',
                     linewidth=2, label='smooth val loss')

        except:
            pass

        plt.grid(True)  # 是否打开画布网格
        # 设置x轴和y轴
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        # 图例的位置上,右上角
        plt.legend(loc="upper right")
        # 保存图片
        plt.savefig(os.path.join(self.loss_dir, "epoch_loss_" + str(self.str_time) + ".png"))
        plt.close('loss')


if __name__ == '__main__':
	# 使用示例
    loss_history = LossHistory('./loss')
    for i in [15, 45, 7, 12, 49, 66, 18, 45, 18, 65, 35, 57]:
        loss_history.add_loss(i / 2, i * 2 / 5)

2021.12.20更新,可以保存任意数量的损失数据

class LossHistoryN:
    """
    可以记录任意数目的损失
    """

    def __init__(self, root, loss_name: [list, tuple] = ('train',), path=None):
        """
        
        :param root: 保存数据的根目录
        :param loss_name: 需要记录的损失的名字
        :param path: 本次保存的路径
        """
        # assert 是满足条件就不打印东西
        assert len(loss_name) != 0, "Loss name shouldn't be empty."
        self.root = self.__check_dir(root)
        self.loss_name = loss_name

        if path is None:
            current_time = datetime.datetime.now()
            str_time = datetime.datetime.strftime(current_time, '%Y_%m_%d_%H_%M_%S')
            self.str_time = str_time
        else:
            self.str_time = path

        self.dir_name = r'loss_%s' % self.str_time
        self.loss_dir = self.__check_dir(os.path.join(self.root, self.dir_name))
        
        # 将数据列表和保存数据的路径的信息保存到字典中
        self.dictionary = {}  # 保存数据
        self.paths = {}  # 保存数据的TXT路径
        for name in loss_name:
            self.dictionary[name] = []
            self.paths[name] = os.path.join(self.loss_dir, '%s.txt' % name)

        self.length = 0
        self.colors = ['red', 'y', 'blue', 'k', 'green', 'm', 'black', 'c']
        self.linestyles = ['--', '-.', ':', '-']

    def add_loss(self, **kwargs):
        """
        
        :param kwargs: 一个参数字典,输入是需要字典的键值和初始化时self.loss_name中的一样才会正常运行
        :return: 
        """
        assert len(kwargs.keys()) == len(self.loss_name), 'kwargs name must equal loss name.'
        names = kwargs.keys()
        # 添加数据到字典中
        for name in names:
            assert name in self.loss_name, 'kwargs.keys must in self.loss_name.'
            loss_value = kwargs[name]
            self.dictionary[name].append(loss_value)
            self.save_loss(self.paths[name], loss_value)

        self.length += 1
        self.plt_loss()

    def plt_loss(self):
        length = range(self.length)
        plt.figure('loss')

        for index, name in enumerate(self.loss_name):
            # 在画布上绘制训练损失曲线
            plt.plot(length, self.dictionary[name], self.colors[index * 2], linewidth=2, label='%s loss' % name)

        try:
            if self.length < 25:
                num = 5
            else:
                num = 15

            for index, name in enumerate(self.loss_name):
                plt.plot(length, scipy.signal.savgol_filter(self.dictionary[name], num, 3), self.colors[index * 2 + 1],
                         linestyle=self.linestyles[index], linewidth=2, label='smooth %s loss' % name)

        except:
            pass

        plt.grid(True)  # 是否打开画布网格
        # 设置x轴和y轴
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        # 图例的位置上,右上角
        plt.legend(loc="upper right")
        # 保存图片
        plt.savefig(os.path.join(self.loss_dir, "epoch_loss_%s.png" % self.str_time))
        plt.close('loss')

    @staticmethod
    def save_loss(save_path, loss):
        with open(save_path, 'a', encoding='utf-8') as file:
            file.write(str(loss) + '\n')

    @staticmethod
    def __check_dir(_path):
        if not os.path.exists(_path):
            os.makedirs(_path)
        return _path


if __name__ == '__main__':
    loss_history = LossHistoryN('./loss', loss_name=('train', 'hot_map', 'w_h', 'offset'))
    for i in [15, 45, 7, 12, 49, 66, 18, 45, 18, 65, 35, 57]:
        loss_history.add_loss(train=i / 2, hot_map=i * 2 / 5, w_h=i * 4 / 3, offset=i * 100 / 555)
  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值