WinClip非官方复现代码学习笔记11

一、程序结构展示

上篇笔记着重讲了metric.py文件的内容,今天主要讲training_utils.py文件的内容。

二、代码功能介绍

  1. get_optimizer_from_args(model, lr, weight_decay, **kwargs): 这个函数用于根据给定的模型、学习率和权重衰减参数创建优化器。它使用 AdamW 优化器,并将模型中所有需要梯度更新的参数传递给优化器。

  2. get_lr_schedule(optimizer): 这个函数用于根据给定的优化器创建学习率调度器。它使用指数衰减方法,每个周期将学习率乘以 0.95。

  3. setup_seed(seed): 这个函数用于设置随机数种子,以确保实验的可复现性。它将随机数种子应用于 PyTorch、NumPy 和 Python 的随机数生成器,同时设置 CUDA 后端以确保在 GPU 上的随机性也是确定的。

  4. get_dir_from_args(root_dir, class_name, **kwargs): 这个函数用于根据给定的根目录、类名和其他参数创建用于保存模型、图片、日志等文件的目录。它根据传入的参数构建实验名称,并创建与实验相关的目录结构。然后,它将记录实验的日志文件路径,并使用 Loguru 日志记录器开始记录日志。

三、代码逐行注释

import random  # 导入 Python 的内置模块 random,用于生成伪随机数
import shutil  # 导入 Python 的内置模块 shutil,用于高级文件操作,如复制、移动、删除文件等
import time  # 导入 Python 的内置模块 time,用于处理时间相关的功能,如获取当前时间、计时等
import torch  # 导入 PyTorch 深度学习框架的主要模块
from torch.utils.tensorboard import SummaryWriter  # 从 PyTorch 的 torch.utils.tensorboard 模块中导入 SummaryWriter 类,用于将数据写入 TensorBoard 日志文件,用于可视化训练过程和结果

from utils.visualization import *  # : 从自定义的 utils.visualization 模块中导入所有内容。通常情况下,这个模块可能包含了一些用于数据可视化的辅助函数或工具
from loguru import logger  # 从 Loguru 日志库中导入 logger 对象。
# Loguru 是一个功能强大且易于使用的日志记录库,提供了丰富的日志记录功能,可以方便地在应用程序中记录日志消息,并支持灵活的日志配置和输出格式化

def get_optimizer_from_args(model, lr, weight_decay, **kwargs) -> torch.optim.Optimizer: # 这个函数用于根据给定的参数创建一个优化器对象 model: 要优化的模型对象。# lr: 学习率。# weight_decay: 权重衰减参数。# **kwargs: 其他参数(可选)
    return torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr,  # 首先使用 filter 函数过滤出模型中需要梯度更新的参数(即 requires_grad=True 的参数)
                             weight_decay=weight_decay)  # 使用 torch.optim.AdamW 类创建一个 AdamW 优化器对象,并将过滤后的参数和学习率、权重衰减参数传递给优化器对象


def get_lr_schedule(optimizer):  # 这个函数用于创建一个学习率调度器对象,用于动态调整优化器中的学习率。接受一个优化器对象作为参数.
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    # 使用 torch.optim.lr_scheduler.ExponentialLR 类创建一个指数衰减的学习率调度器对象。在本例中,指数衰减的因子 gamma 设置为 0.95。函数返回创建的学习率调度器对象


def setup_seed(seed):  # 这个函数用于设置随机种子,以确保实验的可重复性。它接受一个整数参数 seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True  # 将 PyTorch 的 CuDNN 后端设置为确定性模式,以确保在使用 CUDA 加速时的结果一致性
    # 使用 torch.manual_seed、torch.cuda.manual_seed_all、np.random.seed 和 random.seed 分别设置 PyTorch、CUDA、NumPy 和 Python 标准库的随机种子

def get_dir_from_args(root_dir, class_name, **kwargs):  # 根据输入参数创建并返回一些目录路径,并将日志输出到文件中。

    exp_name = f"{kwargs['dataset']}-k-{kwargs['k_shot']}"  # 根据函数参数 kwargs 中的 dataset 和 k_shot 字段的值构建了一个实验名称 exp_name

    csv_dir = os.path.join(root_dir, 'csv')  # 使用 os.path.join 函数将根目录 root_dir 和子目录名称 'csv' 拼接成一个完整的 CSV 文件目录路径 csv_dir
    csv_path = os.path.join(csv_dir, f"{exp_name}-indx-{kwargs['experiment_indx']}.csv")  # 根据实验名称、实验索引和文件后缀构建了一个 CSV 文件路径 csv_path

    model_dir = os.path.join(root_dir, exp_name, 'models')  # 这两行根据根目录、实验名称和子目录名称构建了模型文件夹路径 model_dir 和图片文件夹路径 img_dir
    img_dir = os.path.join(root_dir, exp_name, 'imgs')

    logger_dir = os.path.join(root_dir, exp_name, 'logger', class_name)  # 构建了日志文件夹路径 logger_dir,其中包含了根目录、实验名称、子目录名称 'logger' 和类名 class_name

    log_file_name = os.path.join(logger_dir,
                                 f'log_{time.strftime("%Y-%m-%d-%H-%I-%S", time.localtime(time.time()))}.log')
    # 这一行根据当前时间构建了日志文件名 log_file_name,其中包含了日志文件夹路径、日志文件名前缀 'log_'、日期时间信息和文件后缀 '.log'

    model_name = f'{class_name}'
    # 将输入的类名赋值给 model_name 变量

    os.makedirs(model_dir, exist_ok=True)  # 这几行分别创建了模型文件夹、图片文件夹、日志文件夹和 CSV 文件夹。如果文件夹已存在,则不会抛出异常,这得益于 exist_ok=True 参数
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(logger_dir, exist_ok=True)
    os.makedirs(csv_dir, exist_ok=True)

    logger.start(log_file_name)  # 启动了日志记录器,并将日志输出到指定的日志文件中

    logger.info(f"===> Root dir for this experiment: {logger_dir}")  # 通过日志记录器输出了实验的根目录路径

    return model_dir, img_dir, logger_dir, model_name, csv_path

  • 7
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值