深度学习训练工具包Tools——LearningRateHistory学习率记录工具

工具包Tools——LearningRateHistory学习率记录

import datetime
import os

import matplotlib.pyplot as plt
import scipy.signal


class LearningRateHistory:
    def __init__(self, root):
        """
        :param root: 根目录,主要是保存这个项目运行的所有结果的文件夹
        """
        self.root = root
        self.__assert_dir(root)

        current_time = datetime.datetime.now()
        str_time = datetime.datetime.strftime(current_time, '%Y_%m_%d_%H_%M_%S')
        self.str_time = str_time

        self.dir_name = 'lr_' + self.str_time
        self.dirs = os.path.join(self.root, self.dir_name)
        self.__assert_dir(self.dirs)

        self.learning_rate_path = os.path.join(self.dirs, 'learning_rate.txt')

        self.learning_rate_list = []

    @staticmethod
    def __assert_dir(dir_name):  # 判断文件夹是否存在
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)

    def add_data(self, data):  # 实例化后直接调用这个函数就行,添加loss数据
        self.learning_rate_list.append(data)
        self.save_data(self.learning_rate_path, data)

        self.plt_picture()

    @staticmethod
    def save_data(save_path, data):  # 保存数据
        with open(save_path, 'a', encoding='utf-8') as file:
            file.write(str(data) + '\n')

    def plt_picture(self):  # 绘制数据
        length = range(len(self.learning_rate_list))

        plt.figure('lr')
        plt.plot(length, self.learning_rate_list, 'red', linewidth=2, label='learning rate')  # 在画布上绘制曲线

        try:
            if len(self.learning_rate_list) < 25:
                num = 5
            else:
                num = 15

            plt.plot(length, scipy.signal.savgol_filter(self.learning_rate_list, num, 3), 'green', linestyle='--',
                     linewidth=2, label='smooth learning rate')
        except:
            pass

        plt.grid(True)  # 是否打开画布网格
        # 设置x轴和y轴
        plt.xlabel('Epoch')
        plt.ylabel('lr')
        # 图例的位置上,右上角
        plt.legend(loc="upper right")
        # 保存图片
        plt.savefig(os.path.join(self.dirs, "epoch_loss_" + str(self.str_time) + ".png"))
        plt.close('lr')

if __name__ == '__main__':
    # 使用例子
    lr = LearningRateHistory(root='/lr')
    lr.add_data(data=0.0005)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值