《KnowPrompt》论文代码复现5-main.py代码讲解(超级详细!)

先附上代码,代码注释中会有一点讲解,详细的讲解在代码下面

"""Experiment-running framework."""
import argparse
import importlib
from logging import debug

import numpy as np
from pytorch_lightning.trainer import training_tricks
import torch
import pytorch_lightning as pl
import lit_models
import yaml
import time
from lit_models import TransformerLitModelTwoSteps
from transformers import AutoConfig, AutoModel
from pytorch_lightning.plugins import DDPPlugin
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # 1


# In order to ensure reproducible experiments, we must set random seeds.


def _import_class(module_and_class_name: str) -> type: # 2
    """Import class from a module, e.g. 'text_recognizer.models.MLP'"""
    module_name, class_name = module_and_class_name.rsplit(".", 1) # 3
    module = importlib.import_module(module_name) # 导入module_name代表的模块
    class_ = getattr(module, class_name) # 从moudle中得到class_name代表的类
    return class_


def _setup_parser(): # 4
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments."""
    parser = argparse.ArgumentParser(add_help=False) # 5 创建命令行参数解析器

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser) # 6 将pl.Trainer的参数添加进parser
    trainer_parser._action_groups[1].title = "Trainer Args"  # 7 将trainer_parser的第二组参数的参数组名称改为Trainer Args
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) # 创建新解析器parser,这个解析器继承了trainer_parser的参数

    # Basic arguments
    parser.add_argument("--wandb", action="store_true", default=False) # 8
    parser.add_argument("--litmodel_class", type=str, default="BertLitModel") # 9
    parser.add_argument("--seed", type=int, default=7) # 10
    parser.add_argument("--data_class", type=str, default="WIKI80")
    parser.add_argument("--lr_2", type=float, default=2e-5)
    parser.add_argument("--model_class", type=str, default="RobertaForPrompt") # 11
    parser.add_argument("--two_steps", default=True, action="store_true") # 12
    parser.add_argument("--load_checkpoint", type=str, default=None) # 13
    # add
    parser.add_argument("--freeze", default=True) # 14

    parser.add_argument("--num_masks", type=int, default=2) # 15

    # Get the data and model classes, so that we can add their specific arguments
    temp_args, _ = parser.parse_known_args() # 16 temp_args得到命令行参数解析器中已知参数的值
    data_class = _import_class(f"data.{temp_args.data_class}") # 参数 data.WIKI80
    model_class = _import_class(f"models.{temp_args.model_class}") # 参数 models.RobertaForPrompt
    litmodel_class = _import_class(f"lit_models.{temp_args.litmodel_class}") # 参数 lit_models.BertLitModel

    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args") # 17 创建参数组data_group
    data_class.add_to_argparse(data_group) # 把data_class添加进data_group

    model_group = parser.add_argument_group("Model Args")
    model_class.add_to_argparse(model_group)

    lit_model_group = parser.add_argument_group("LitModel Args")
    litmodel_class.add_to_argparse(lit_model_group)

    parser.add_argument("--help", "-h", action="help") # 18
    return parser


device = "cuda"
from tqdm import tqdm
def _get_relation_embedding(data):
    train_dataloader = data.train_dataloader() # 19
    #! hard coded
    relation_embedding = [[] for _ in range(36)] # 20
    model = AutoModel.from_pretrained('bert-base-uncased') # 21
    model.eval()
    model = model.to(device)


    cnt = 0
    for batch in tqdm(train_dataloader): # 22
        with torch.no_grad():
            #! why the sample in this case will cause errors
            if cnt == 416:
                continue
            cnt += 1
            input_ids, attention_mask, token_type_ids , labels = batch
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask,
                           token_type_ids=token_type_ids).last_hidden_state.detach().cpu() # 23
            _, mask_idx = (input_ids == 103).nonzero(as_tuple=True) # 24
            bs = input_ids.shape[0]
            mask_output = logits[torch.arange(bs), mask_idx] # [batch_size, hidden_size] # 25
            

            labels = labels.detach().cpu()
            mask_output = mask_output.detach().cpu()
            assert len(labels[0]) == len(relation_embedding) # 26
            for batch_idx, label in enumerate(labels.tolist()): # 27
                for i, x in enumerate(label):
                    if x:
                        relation_embedding[i].append(mask_output[batch_idx])
    
    # get the mean pooling
    for i in range(36):
        if len(relation_embedding[i]):
            relation_embedding[i] = torch.mean(torch.stack(relation_embedding[i]), dim=0) # 28
        else:
            relation_embedding[i] = torch.rand_like(relation_embedding[i-1]) # 29

    del model
    return relation_embedding


def main():
    parser = _setup_parser()
    args = parser.parse_args() # 30

    np.random.seed(args.seed) # 设置Numpy库中随机数生成器的种子
    torch.manual_seed(args.seed) # 设置PyTorch库中随机数生成器的种子
    pl.seed_everything(args.seed) # 设置PyTorch Lighting库中随机数生成器的种子
    data_class = _import_class(f"data.{args.data_class}") # 从data目录下的dialogue.py文件中引入WIKI80
    model_class = _import_class(f"models.{args.model_class}") # 从models目录下的roberta目录下的_init_.py文件中引入RobertaForPrompt
    litmodel_class = _import_class(f"lit_models.{args.litmodel_class}") # lit_models目录下

    config = AutoConfig.from_pretrained(args.model_name_or_path) # 31
    model = model_class.from_pretrained(args.model_name_or_path, config=config) # 加载模型,配置是上面的代码引入的配置
    data = data_class(args, model) # 32 按照命名空间中的要求加载数据集,用模型处理数据集
    data_config = data.get_data_config() # 33 获取数据集的配置信息
    model.resize_token_embeddings(len(data.tokenizer)) # 34 按照分词器词汇表大小调整嵌入层的大小

    # BertLitModel
    lit_model = litmodel_class(args=args, model=model, tokenizer=data.tokenizer) # 加载BertLitModel
    data.tokenizer.save_pretrained('test') # 34(2) 保存分词器到本地目录test
    # print('lit_model',lit_model)

    logger = pl.loggers.TensorBoardLogger("training/logs") # 35
    dataset_name = args.data_dir.split("/")[-1] # 36
    if args.wandb:
        logger = pl.loggers.WandbLogger(project="dialogue_pl", name=f"{dataset_name}") # 37
        logger.log_hyperparams(vars(args)) # 38
    
    # init callbacks
    early_callback = pl.callbacks.EarlyStopping(monitor="Eval/f1", mode="max", patience=5,check_on_train_epoch_end=False) # 39
    model_checkpoint = pl.callbacks.ModelCheckpoint(monitor="Eval/f1", mode="max", filename='{epoch}-{Eval/f1:.2f}', dirpath="output", save_weights_only=True) # 40
    callbacks = [early_callback, model_checkpoint]

    # args.weights_summary = "full"  # Print full summary of the model
    gpu_count = torch.cuda.device_count()
    accelerator = "ddp" if gpu_count > 1 else None # 41
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=callbacks,
                                            logger=logger,
                                            default_root_dir="training/logs",
                                            gpus=gpu_count,
                                            accelerator=accelerator,
                                            plugins=DDPPlugin(find_unused_parameters=False) if gpu_count > 1 else None,) #42
    # trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate
    trainer.fit(lit_model, datamodule=data) # 43

    # After training finishes, use best_model_path to retrieve(检索) the path to the best checkpoint file and best_model_score to retrieve its score.
    path = model_checkpoint.best_model_path
    print(f"best model save path {path}")
    if not os.path.exists("config"):
        os.mkdir("config")
    config_file_name = time.strftime("%H:%M:%S", time.localtime()) + ".yaml" # 44
    day_name = time.strftime("%Y-%m-%d")
    if not os.path.exists(os.path.join("config", day_name)):
        os.mkdir(os.path.join("config", time.strftime("%Y-%m-%d")))
    config = vars(args)
    config["path"] = path # 保存
    with open(os.path.join(os.path.join("config", day_name), config_file_name), "w") as file: # 最后这个才是整个文件名
        file.write(yaml.dump(config))

    # one step
    if not args.two_steps:
        trainer.test() # 45

    # 从第一阶段过渡到第二阶段的提示 # 46
    print('args.freeze:', args.freeze)
    args.freeze = False
    print('args.freeze:', args.freeze)

    step2_model_checkpoint = pl.callbacks.ModelCheckpoint(monitor="Eval/f1", mode="max", filename='{epoch}-{Step2Eval/f1:.2f}', dirpath="output", save_weights_only=True)
    if args.two_steps:
        # we build another trainer and model for the second training
        # use the Step2Eval/f1

        lit_model_second = TransformerLitModelTwoSteps(args=args, model=lit_model.model, tokenizer=data.tokenizer)
        lit_model_second.load_state_dict(torch.load(path)["state_dict"]) # 47

        step_early_callback = pl.callbacks.EarlyStopping(monitor="Eval/f1", mode="max", patience=6, check_on_train_epoch_end=False)
        callbacks = [step_early_callback, step2_model_checkpoint]
        trainer_2 = pl.Trainer.from_argparse_args(args,
                                                  callbacks=callbacks,
                                                  logger=logger,
                                                  default_root_dir="training/logs",
                                                  gpus=gpu_count,
                                                  accelerator=accelerator,
                                                  plugins=DDPPlugin(find_unused_parameters=False) if gpu_count > 1 else None,)
        trainer_2.fit(lit_model_second, datamodule=data)
        trainer_2.test()

main()

2023.12.29

1、os.environ["TOKENIZERS_PARALLELISM"] = "false"

这句代码是禁止分词器tokenizer的并行处理,避免死锁等问题的出现(“死锁”好熟悉的词语…一下子回到考研复习操作系统的日子了…555好难过我的梦中情校…),具体为什么要用这句代码GPT给的解释我没太看明白,先继续往后读了

2、def _import_class(module_and_class_name: str) -> type:

这个函数的定义用到了python的类型提示,用于描述函数的输入参数和返回值类型

“module_and_class_name: str”的意思是说这个函数的参数应该是string类型的,就像c定义函数的时候,要限制参数类型那样,比如,float func(int a),这里就明确定义了输入进func函数的参数a是int类型的

“-> type”是说函数的返回值应该是type类型的,这个type类型包括python所有的合法类型

3、module_name, class_name = module_and_class_name.rsplit(".", 1)

使用rsplit函数,以“.”为分隔符对module_and_class_name进行拆分

拆分只进行1次,从module_and_class_name右边开始,从右向左查找分隔符“.”

拆分后得到一个包含两个元素的列表,两个元素分别赋值给module_name和class_name

比如module_and_class_name的值为“a.b.c.d”,对其使用rsplit(".", 1)后,得到[“a.b.c”,”d”],然后a.b.c赋值给module_name,d赋值给class_name

4、def _setup_parser():

这个函数的作用是创建“命令行参数解析器”。命令行参数解析器就是,当你在pycharm终端执行readme文件中的那些命令时,帮你实现这些命令的东西

比如,readme文件中有这么一条命令

这个命令的含义是:执行get_label_word.py 这个文件,model_name_or_path这个参数赋值为bert-large-uncased,dataset_name这个参数赋值为semeval。这些操作都是通过命令行参数解析器实现的

2023.12.30

5、parser = argparse.ArgumentParser(add_help=False)

argparse 是 Python 标准库中的一个模块,用于解析命令行参数和生成帮助信息

ArgumentParser 是 argparse 模块中的一个类,用于创建命令行参数解析器对象

add_help=False:这个参数指示解析器不要自动生成默认的帮助信息。默认情况下,argparse 会为解析器自动添加一个 --help 或 -h 的选项,用于显示帮助信息。用户运行脚本时使用 --help 或 -h 选项,脚本会显示帮助信息并退出,不会执行脚本的主要功能。通过将 add_help 设置为 False,你可以禁用这个默认的帮助选项

6、trainer_parser = pl.Trainer.add_argparse_args(parser)

将 PyTorch Lightning 中的 pl.Trainer 类的参数添加到命令行参数解析器 parser 中。这样一来,就可以在命令行中设置训练的最大周期数、使用的 GPU 数量、训练精度等等。比如:python your_script.py --max_epochs 10 --gpus 2 --precision 16

7、trainer_parser._action_groups[1].title = "Trainer Args"

一开始定义的parser中只有一组参数,使用“trainer_parser = pl.Trainer.add_argparse_args(parser)”后,又将pl.Trainer的参数添加进解析器中,这样一来解析器中就有两组参数了

_action_groups[1]就是代表第二组参数,_action_groups[1].title就是第二组参数的标题

这句代码就是修改第二组参数的标题为"Trainer Args",这样当我们使用—help来查看帮助信息的时候,就能看到第二组参数的标题是(参数不一定是图中的那些参数,举个例子)

8、parser.add_argument("--wandb", action="store_true", default=False)

向命令行参数解析器parser中加入参数“wandb”

action="store_true"的含义是,当命令行包含了特定选项时,参数的值将被设置为 True,否则参数的值将保持默认值,这里默认值通过default=False被设置为了false

举个例子来解释就是:

在命令行使用python my_script.py --wandb执行my_script.py这个文件时,wandb参数的值将被设置为true;如果使用python my_script.py执行my_script.py这个文件时,wandb参数的值将被设置为false

wandb这个参数的作用:

9、parser.add_argument("--litmodel_class", type=str, default="BertLitModel")

向parser中添加参数litmodel_class,这个参数值被“type=str”限制住,只能是字符串类型的,默认这个参数的值是BertLitModel

litmodel_class这个参数的作用:

这个参数用于指定脚本中使用的PyTorch Lightning LitModel模型(Lightning 模型),这个litmodel并不是我们平时所说的“模型”,我们平时所说的“模型”在11中指定,在11中我们也会介绍model和litmodel的区别

10、parser.add_argument("--seed", type=int, default=7)

向命令行参数解析器中添加参数seed,seed代表随机数种子,seed的值必须是int型的,默认值是7

随机数种子的作用是:

在我们训练模型的时候,初始化模型的权重和偏执的时候,不是程序员一个一个的输入这些参数的初始值,而是使用函数随机生成一组数作为初始值。设置了种子后,每次运行代码的时候,模型的权重和偏执的初始值就是一样的。保证了实验的可重复性等等诸多好处

11、parser.add_argument("--model_class", type=str, default="RobertaForPrompt")

指定训练模型使用的模型,默认使用RobertaForPrompt这个模型

model和litmodel的区别

model定义了神经网络的架构(例如,层数、隐藏单元数、注意力机制等)和它们如何处理输入数据(即前向传播逻辑)(也就是model定义了代码实现)

litmodel 通常是指使用 PyTorch Lightning 库创建的模型类。PyTorch Lightning 是一个高级库,旨在简化复杂的模型训练过程。litmodel 类通常封装了一个基础模型(比如 bert-large-uncased 或 roberta),并添加了训练、验证和测试循环的逻辑,以及其他可能需要的步骤,比如优化器和学习率调度器的配置。简单来说,litmodel 更多地关注于如何使用基础模型进行有效的训练和验证

我们在前面的代码中用过“model_name_or_path”这样一个变量,我们给这个变量赋值为roberta-large,那么model_name_or_path定义的值,和model_class定义的值,有什么区别呢?

model_name_or_path,以这里的roberta-large为例,定义了模型的权重和配置。我们将这个模型下载到本地的时候,下载了这四个文件

config.json: 这个文件包含了模型的配置信息,包括模型的架构、超参数、词汇表大小、层数、隐藏单元数等。它描述了模型的整体结构和配置,允许你创建模型实例并了解模型的详细信息。文件内容如图:

merge.txt: 这个文件包含了词汇表中的所有单词的子词(subword)信息,用于分词和文本编码

vocab.json: 这个文件包含了模型的词汇表信息,包括单词和它们的编码

pytorch_model.bin: 这个文件包含了模型的权重参数。它存储了模型的所有权重,包括各个层的权重矩阵、注意力头的权重等等

所以说,可以看到,我们通过model_name_or_path引入的roberta-large其实是不涉及模型的代码实现的,而只是引入了这个模型的一些配置信息

而model_class则是涉及到了模型的代码实现

举例解释model和litmodel的区别,比如我们要训练一个文本分类器:

model就是对输入的文本数据进行理解、转成高维度的特征表示等一系列处理。在这个阶段,我们关心的是模型如何处理输入数据,以及它如何转换成一个特征表示。我们并不涉及如何训练模型,只关注它的内部工作原理

而litmodel是:我们将model模型封装在一个 PyTorch Lightning 模型(即 litmodel)中。这个 litmodel 包含了model,还在其基础上添加了额外的层次,比如一个分类头(分类层),用于基于model的特征输出来做分类。还定义了训练循环(如何使用训练数据来更新模型)、验证循环(如何评估模型的性能)和测试循环。我们还可能定义优化器、学习率调度器等

以上就是model_name_or_path、model_class、litmodel_class这三者的区别了

2024.1.4

12、parser.add_argument("--two_steps", default=True, action="store_true")

添加命令行参数”two-steps”,这个参数的添加方式和参数”wandb”的添加方式一样。但与wandb的默认值为false不同,当命令行中的命令是类似“python my_script.py --wandb”这种时,wandb这个参数的值才为true;而由于two_steps的默认值就是true,所以在命令中不显式使用two_steps时,即使用类似“python my_script.py”这种命令时,two_steps的值也为true

单阶段训练和两阶段训练的区别

13、parser.add_argument("--load_checkpoint", type=str, default=None)

添加命令行参数” load_checkpoint”,这个参数的值必须为字符串,默认值为空

这个参数用于指定要加载的模型的模型检查点文件的路径。模型检查点文件通常包含了已经训练好的模型的权重以及其它必要的信息,可以使用类似“python my_script.py - -load_checkpoint / path / to / checkpoint.pth”这种指令来加载模型检查点文件。使用上述指令时,将从“/ path / to / checkpoint.pth”这个路径中加载模型检查点文件,那么模型将使用这个文件中保存的模型状态来继续进行训练或进行其它任务;如果不加载模型检查点文件则模型从头开始训练

14、parser.add_argument("--freeze", default=True)

freeze参数通常用于冻结某些层的权重,即在训练过程中保持某些层的权重不变

15、parser.add_argument("--num_masks", type=int, default=2)

这个参数的作用是允许用户在命令行中指定一个整数值,该值可能会在脚本中用于控制某些操作或参数设置。脚本可以使用这个值来执行特定的操作或自定义参数设置。如果用户不提供 --num_masks 参数,参数的值将保持默认值

2024.1.10

16、temp_args, _ = parser.parse_known_args()

从命令行参数解析器parser中通过parse_known_args()方法解析出已知参数,然后把结果赋值给temp_args;而“_”表示“不打算使用未知参数的值”

举例解释:

在命令行运行该脚本时

可以看到,我们在命令行参数解析器中只添加了name和age这两个参数,但是传递参数的时候却还传递了other_arg这个参数。那么在运行temp_args, _ = parser.parse_known_args()时,temp_args中就只有name和age这两个参数,other_arg这个参数就会被忽略

debug时也可以看到,从temp_args所代表的命名空间中,可以看到,我们上面向命令行参数解析中添加参数时给的默认值,如图:

17、data_group = parser.add_argument_group("Data Args")

创建一个名为“Data Args”的参数组,并将这个参数组赋值给data_group

然后接下来的“data_class.add_to_argparse(data_group)”这句代码,“add_to_argparse”这个方法下面标黄色波浪线了,说data_class这个类下没有add_to_argparse这个方法。我也不知道是为什么,后面再看吧

18、parser.add_argument("--help", "-h", action="help")

向命令行参数解析器中添加参数“help”或“h”,“help”和“h”都是参数的名字;action="help" 表示当用户在命令行中使用

--help 或 -h 参数时,程序将自动显示解析器中的帮助信息

举例解释:

当用户在命令行中运行 my_program.py 并使用 --help 或 -h 参数时,程序会显示如下帮助信息:

2024.1.11

19-29对应“获取关系嵌入”这个函数的代码,但是这个函数好像没有被用到,这里就不再做更详细的说明了

19、train_dataloader = data.train_dataloader()

从输入的数据对象 data 中获取训练数据集的数据加载器 train_dataloader。数据加载器用于批量加载数据

20、relation_embedding = [[] for _ in range(36)]

创建名为relation_embedding的列表,这个列表中有36个空列表。“_”是一个占位符,表示我们在这里不需要使用循环变量的值

21、model = AutoModel.from_pretrained('bert-base-uncased')

AutoModel 是 Hugging Face Transformers 库中的一个类,用于加载和初始化各种自然语言处理模型。from_pretrained('bert-base-uncased') 是一个方法调用,它告诉模型从 Hugging Face 模型存储库中下载并加载一个名为 'bert-base-uncased' 的预训练 BERT 模型

22、tqdm(train_dataloader)

将迭代过程可视化,通常以进度条的形式显示,以便用户可以看到迭代的进度。这对于监视模型训练进展以及估计训练时间非常有用

23、logits = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids).last_hidden_state.detach().cpu()

使用模型对输入文本进行编码(前向计算),并获取模型的输出logits,这个输出包含了对输入文本的表示,通常用于进一步的任务,例如文本分类或特征提取

input_ids:这是一个张量,其中包含了文本数据的输入标识符(input identifiers)。input_ids 通常是一个整数序列,用于表示文本内容

attention_mask:这是一个张量,用于指示哪些标记需要在模型中进行关注的掩码(mask)。attention_mask 通常是一个二进制序列,其中值为 1 表示关注,值为 0 表示忽略

token_type_ids:这是一个张量,用于表示文本中不同部分的标识符。在某些任务中,文本可能包含多个部分,例如对话中的不同发言者。token_type_ids 通常是一个整数序列,用于区分不同部分

last_hidden_state:这是 BERT 模型的一个属性,它代表模型的最后一个隐藏状态。在 BERT 模型中,最后一个隐藏状态通常包含了对输入文本的编码信息,可以用于进一步的任务或分析

.detach():这是一个 PyTorch 张量的方法,用于分离张量,即将其从计算图中分离出来,使其不再与梯度计算相关,后续对其的操作不会影响梯度计算

.cpu():这是一个 PyTorch 张量的方法,用于将张量移动到 CPU 上

24、_, mask_idx = (input_ids == 103).nonzero(as_tuple=True)

input_ids == 103: 这是一个逐元素比较,返回一个布尔值数组,其中 True 表示 input_ids 中的相应元素等于 103

.nonzero(): 这个方法返回所有 True 元素的索引。在 PyTorch 中,当 as_tuple=True 时,nonzero 方法返回一个元组,其中包含多个数组,每个数组代表了 True 元素在不同维度上的索引

第一个返回值 (_): 这个变量包含了 True 元素在第一个维度(例如,在二维数组中是行)的索引。因为使用了 _,这意味着我们不关心这个信息,不打算在后续代码中使用它

第二个返回值 (mask_idx): 这个变量包含了 True 元素在第二个维度(例如,在二维数组中是列)的索引。在这个场景中,因为我们关心的是等于 103 的元素的位置,mask_idx 就是我们感兴趣的索引集

简而言之,_ 忽略了第一个维度的索引,而 mask_idx 保留了第二个维度的索引

25、mask_output = logits[torch.arange(bs), mask_idx]

logits是模型的输出,是一个张量

torch.arange(bs)是一个一维张量,由整数0 — bs-1 组成

mask_idx是一个一维张量

logits[一维张量,一维张量]是从logits这个张量中选择特定的值组成新的张量。举例:

logits = torch.tensor([

     [0.1, 0.2, 0.3],

     [0.4, 0.5, 0.6],

     [0.7, 0.8, 0.9]

])

torch.arange(bs)=torch.tensor([0,1,2])

mask_idx= torch.tensor([1,0,2])

于是mask_output = logits[torch.tensor([0,1,2]), torch.tensor([1,0,2])]=([ logits[0,1], logits[1,0], logits[2,2] ])= ([ 0.2,0.4,0.9 ])

2024.1.12

26、assert len(labels[0]) == len(relation_embedding)

这行代码的作用是进行断言,检查标签中第一个样本的长度和relation_embedding列表的长度是否相同。如果相同则继续执行代码,不同则抛出AssertionError异常

27、for batch_idx, label in enumerate(labels.tolist()):

labels.tolist() 是一个将 PyTorch 张量(tensor)转换为 Python 列表(list)的操作。具体来说,它会将张量中的元素逐个提取出来,然后存储在一个 Python 列表中

enumerate(labels.tolist()) 是一个用于在迭代列表时获取索引和值的常见操作。具体来说,它会将一个列表中的元素逐个提取出来,并为每个元素生成一个索引-值对(index-value pair)

28、torch.mean(torch.stack(relation_embedding[i]), dim=0)

torch.stack(relation_embedding[i]):这部分代码将 relation_embedding[i] 中的张量堆叠(stack)成一个新的张量。具体来说,如果 relation_embedding[i] 包含多个张量,torch.stack 会将它们按照指定的维度(通常是新的维度0)进行堆叠,创建一个包含这些张量的新张量

torch.mean(..., dim=0):这是一个计算均值的操作,其中 ... 表示上一步骤中得到的堆叠的张量。dim=0 参数指定了在哪个维度上计算均值。在这里,dim=0 表示计算每一列的均值,即在堆叠的张量的每一列上计算均值

29、torch.rand_like(relation_embedding[i-1])

torch.rand_like(relation_embedding[i-1]) 的作用是创建一个与 relation_embedding[i-1] 张量具有相同形状的随机张量,其中每个元素都是随机生成的值。这个操作通常用于初始化或填充一个张量,以便后续的计算或处理。

30、args = parser.parse_args()

使用命令行参数解析器parser的parser_args()方法获取命令函的参数,并将这些参数储存在命名空间对象中,给这个对象起名为args

可以通过“.”来访问命名空间对象中的参数。比如命令行参数包括 --wandb、--lr_2、--data_class 等等,可以通过 args.wandb、args.lr_2、args.data_class 来访问它们的值

需要注意的是,每个py文件都有自己独立的命名空间,a.py中的命令行参数解析器无法访问b.py的命名空间中的变量

31、config = AutoConfig.from_pretrained(args.model_name_or_path)

这行代码用于加载预训练模型的配置

AutoConfig 是 Transformers 库中的一个类,它用于加载各种不同类型的预训练模型(如BERT、GPT-2、RoBERTa等)的配置信息

from_pretrained(args.model_name_or_path) 是一个静态方法,用于从指定的模型名称或路径中加载配置信息。这个方法会根据模型名称或路径自动确定要加载的模型类型,并返回相应的配置。如果参数是模型名,那么就要联网从hugging face上下载模型;如果参数是本地路径,就不用联网了

32、data = data_class(args, model)

args:这是命令行参数,通常包含了用于配置数据加载的参数,例如数据文件的路径、数据预处理方式等

model是上一行创建的模型

这行代码的作用是使用数据集类data_class,按照args中的要求加载数据集、用模型处理数据集,然后赋值给变量data

33、data_config = data.get_data_config()

get_data_config() 是一个自定义方法或属性,用于获取数据集的配置信息。这个方法通常用于返回有关数据集的元数据,例如数据集的特征数、标签数、数据预处理方式等

34、model.resize_token_embeddings(len(data.tokenizer))

resize_token_embeddings是模型用来调整嵌入层大小的方法

len(data.tokenizer)是分词器中词汇表的数量大小

34、(2)、data.tokenizer.save_pretrained('test')

这行代码将data.tokenizer中的分词器相关信息(例如词汇表、模型参数等)保存到名为'test'的本地目录中。这个操作通常用于将已经训练好的分词器保存下来,以便在以后的任务中重用它们,而不必重新训练

35、logger = pl.loggers.TensorBoardLogger("training/logs")

pl.loggers.TensorBoardLogger 是 PyTorch Lightning 提供的一个日志记录器类,用于将训练过程中的日志信息记录到 TensorBoard 格式的日志文件中

"training/logs" 是指定日志文件的保存路径

通过执行这行代码,你创建了一个 TensorBoardLogger 日志记录器对象 logger,该对象可以在训练期间将各种日志信息(如损失、指标、可视化等)记录到指定的日志文件中。TensorBoard 是一个用于可视化和监控训练过程的工具,你可以使用它来查看训练过程中的指标趋势和模型性能

36、dataset_name = args.data_dir.split("/")[-1]

data_dir这个命令行参数是在generate_k_shot.py这个文件中被加入到命名空间的。它的默认值是“../datasets”,就是进入目录结构中的dataset这个文件夹

args.data_dir 是一个包含数据集路径的命令行参数

split("/") 是一个字符串分割操作,它将数据集路径根据斜杠 / 进行分割,将路径拆分为多个部分

[-1] 是索引操作,它用于获取分割后的字符串列表的最后一个元素,即数据集名称。一般数据集路径的最后一部分就是数据集的名字

37、logger = pl.loggers.WandbLogger(project="dialogue_pl", name=f"{dataset_name}")

pl.loggers.WandbLogger 是 PyTorch Lightning 提供的一个日志记录器类,用于将训练过程中的日志信息记录到 Weights and Biases (WandB) 平台上

project="dialogue_pl" 是指定 WandB 项目的名称,用于将日志信息记录到特定的项目中。你可以根据自己的需要更改项目名称

name=f"{dataset_name}" 是指定实验名称,通常是与数据集相关的名称

通过执行这行代码,你创建了一个 WandbLogger 日志记录器对象 logger,该对象可以在训练期间将各种日志信息(如损失、指标、可视化等)记录到 WandB 平台上的指定项目和实验中。WandB 是一个用于实验跟踪、日志记录和可视化的平台,可以帮助你监控和分析模型训练过程中的各种信息

TensorBoardLogger和WandbLogger这两个日志记录器类最大的不同:tensorboard只适用于tensorflow和pytorch框架,而wandb适合所有深度学习框架

38、logger.log_hyperparams(vars(args))

将命令行参数(通常是训练过程的超参数)记录到日志的操作,通常用于实验跟踪和记录模型训练过程中的超参数设置

39、early_callback = pl.callbacks.EarlyStopping(monitor="Eval/f1", mode="max", patience=5,check_on_train_epoch_end=False)

early_callback 是一个用于提前停止训练的回调对象,通常用于避免过拟合或在训练不再改善时节省计算资源

具体来说,这行代码使用 PyTorch Lightning 提供的 EarlyStopping 回调类创建了一个名为 early_callback 的回调对象,其配置如下:

monitor="Eval/f1":这是回调函数监视的指标名称。EarlyStopping 会根据这个指标的值来决定是否停止训练。在这里,它监视 "Eval/f1" 指标,通常是验证集上的 F1 分数(F1分数:在深度学习模型训练过程中一种用于评估模型性能的指标。它通常用于分类任务,特别是二分类任务,例如情感分析、文本分类、图像分类。在训练过程中,模型会在每个训练周期(或称为 epoch)结束后进行验证,以评估其在验证集上的性能。"F1 分数" 是一个综合性能指标,结合了模型的精确度和召回率。计算公式为:(2x查准率x查全率)/(查准率+查全率)。通常情况下,F1 分数的取值范围在 0 到 1 之间,越接近 1 表示模型性能越好。)

mode="max":这是模式参数,指定了如何确定是否达到停止条件。"max" 意味着当监视的指标达到最大值时停止训练。如果你希望在监视指标达到最小值时停止训练,可以将 mode 设置为 "min"

patience=5:这是容忍参数,表示在监视指标没有改善的情况下,等待多少个训练周期后才停止训练。在这里,如果 "Eval/f1" 指标在连续 5 个训练周期中没有改善,则训练将停止

check_on_train_epoch_end=False:这个参数表示是否在每个训练周期结束时检查监视指标。如果设置为 True,则在每个训练周期结束后都会检查一次,如果满足停止条件,就会停止训练。在这里,它设置为 False,表示只在验证集上的评估周期结束时检查监视指标

通过将这个提前停止的回调对象传递给 PyTorch Lightning 的训练器(Trainer),你可以在训练过程中启用提前停止功能,以便在验证集上的性能不再改善时自动停止训练,从而避免过拟合

40、model_checkpoint = pl.callbacks.ModelCheckpoint(monitor="Eval/f1", mode="max", filename='{epoch}-{Eval/f1:.2f}', dirpath="output", save_weights_only=True)

model_checkpoint 是一个用于保存模型检查点的回调对象,通常用于在训练过程中保存模型的权重或整个模型,以便在训练结束后或需要恢复训练时使用

具体来说,这行代码使用 PyTorch Lightning 提供的 ModelCheckpoint 回调类创建了一个名为 model_checkpoint 的回调对象,其配置如下:

monitor="Eval/f1":这是回调函数监视的指标名称。ModelCheckpoint 会根据这个指标的值来决定是否保存模型检查点。在这里,它监视 "Eval/f1" 指标,通常是验证集上的 F1 分数

mode="max":这是模式参数,指定了如何确定是否保存模型检查点。"max" 意味着当监视的指标达到最大值时保存检查点。如果你希望在监视指标达到最小值时保存检查点,可以将 mode 设置为 "min"

filename='{epoch}-{Eval/f1:.2f}':这是保存模型检查点的文件名模板,其中 {epoch} 表示当前训练周期的编号,{Eval/f1:.2f} 表示当前 "Eval/f1" 指标的值,保留两位小数。这将在每个训练周期结束时生成一个唯一的文件名,以保存模型检查点

dirpath="output":这是保存模型检查点的文件夹路径。模型检查点将保存在名为 "output" 的文件夹中。你可以根据需要更改保存路径

save_weights_only=True:这个参数表示是否仅保存模型的权重而不保存整个模型。如果设置为 True,则只保存模型权重;如果设置为 False,则保存整个模型。在这里,它设置为 True,表示只保存权重

通过将这个模型检查点的回调对象传递给 PyTorch Lightning 的训练器(Trainer),你可以在训练过程中启用模型检查点功能,以便在每个训练周期结束时根据指定的条件保存模型的权重或整个模型。这对于在训练中定期保存模型、避免训练中断或用于后续评估和推理非常有用

41、accelerator = "ddp" if gpu_count > 1 else None

根据GPU数量选择适当的训练加速器

如果 gpu_count > 1,即系统中有多个 GPU 可用,那么 accelerator 被设置为 "ddp",表示使用分布式数据并行(Distributed Data Parallel)加速器。Distributed Data Parallel 允许在多个 GPU 上并行训练模型,以加速训练过程

如果 gpu_count <= 1,即系统中只有一个或没有 GPU 可用,那么 accelerator 被设置为 None,表示不使用任何特殊的分布式加速器,而是在单个 GPU 或 CPU 上进行训练

42、trainer = pl.Trainer.from_argparse_args(args,…, default_root_dir="training/logs",…, plugins=DDPPlugin(find_unused_parameters=False) if gpu_count > 1 else None)

创建 PyTorch Lightning 的训练器(Trainer)对象,并配置训练器的各种参数,以便进行模型训练

pl.Trainer.from_argparse_args(args, ...):通过调用 from_argparse_args 方法,从命令行参数 args 中加载训练器的配置参数。这是一种方便的方式,允许从命令行传递参数来配置训练器

default_root_dir="training/logs":指定模型训练过程中保存日志和检查点的根目录

plugins=DDPPlugin(find_unused_parameters=False) if gpu_count > 1 else None:这是一个插件配置,用于根据 GPU 数量选择是否启用 Distributed Data Parallel("ddp")插件。如果 gpu_count > 1,则启用 "ddp" 插件,并设置 find_unused_parameters=False,表示不查找未使用的参数

DDPlugin插件

我们在使用 PyTorch Lightning 的 Trainer 时,已经在 accelerator 参数中设置了正确的分布式加速器(即 "ddp"),所以实际上不需要显式添加 DDPPlugin 插件。accelerator 参数的设置已经包括了 DDP 训练的配置,会自动启用 DDP 并管理 GPU 上的并行训练

43、trainer.fit(lit_model, datamodule=data)

这行代码用于开始模型的训练过程

lit_model:这是你要训练的 PyTorch Lightning 模型,通常LightningModule 的子类,包含模型的定义和训练逻辑

datamodule=data:这是数据模块,用于提供训练和验证数据集以及数据加载器

trainer.fit 方法将使用指定的模型和数据模块来执行训练循环,包括多个训练周期(epochs),每个周期包括数据加载、前向传播、反向传播、参数更新等步骤。具体来说,它会执行以下操作:

2021.1.13

44、config_file_name = time.strftime("%H:%M:%S", time.localtime()) + ".yaml"

这行代码用于生成一个配置文件的文件名,文件名的格式基于当前本地时间

time.strftime("%H:%M:%S") 用于获取当前本地时间,并将其格式化为时:分:秒的字符串。"%H:%M:%S" 是时间格式字符串,表示时、分和秒,比如可能得到的结果“15:30:16”。使用time.strftime("%H:%M:%S", time.localtime())的原因是,不加第二个参数,得到的时间是基于系统默认的时区,可能与我们的时区不同,加入第二个参数是为了得到我们所在时区的时间

.yaml 为文件扩展名

45、trainer.test()

trainer.test() 是 PyTorch Lightning 中的一个方法,用于在训练结束后运行模型的测试阶段。在测试阶段,模型将被用来评估其在测试数据集上的性能,通常用于生成测试结果、计算指标(如准确率、F1 分数等)或者进行其他与模型性能评估相关的操作

具体来说,trainer.test() 的功能包括:将模型切换到评估模式、在测试数据集上运行模型,获取模型的输出、计算并记录各种性能指标(如损失、准确率、F1 分数等)、打印或记录测试结果

46、   

# 从第一阶段过渡到第二阶段的提示

print('args.freeze:', args.freeze)

args.freeze = False

print('args.freeze:', args.freeze)

为什么说两个阶段的切换标志是freeze变量的值由true变为false呢?因为在两阶段训练中

47、lit_model_second.load_state_dict(torch.load(path)["state_dict"])

path是训练过程中性能最佳的模型检查点文件的路径

torch.load(path) 用来加载这个路径

torch.load(path)["state_dict"] 用于取出state_dict这个键对应的值。state_dict这个键是代表模型权重参数的键

  • 11
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值