【Transformers源码】快速debug model tips
debug源码可以更快速高效的学习model,是不是要等模型文件下载完,才能debug,等待下载LLM的模型文件是不是内心有那么亿丢丢的焦灼…
一、蠢哭自己的方法
笔者一直用的蠢哭了的方法,分为三步
- huggingface官网下载model文件
- 写代码加载模型
- debug源码
from transformers import AutoModel, AutoTokenizer # 加载模型和分词器 model_path = "****/chinese-roberta-wwm-ext" model = AutoModel.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path) # 处理文本 input_txt = "甄天真真天真" encode_dict = tokenizer( [input_txt], return_tensors="pt", ) # 模型推理 model(**encode_dict)
缺点:
- 下载了一堆模型文件,是不是占电脑存储空间
- 小模型等等也不是很要紧,下载LLM是不是感受到了大把好时光在流逝
二、跟电子导师学到的方法【model-free】
强裂推荐笔者的电子导师大佬的方法transformers源码阅读——llama模型调试
再也不用等模型文件下载完,就可以直接debug模型源码啦~~~
BertModel
from transformers.models.bert import BertConfig, BertModel
import torch
def run_bert():
bert_config = BertConfig()
bert_model = BertModel(config=bert_config)
# batch_size=4, seq_len=30
input_ids = torch.randint(low=0, high=bert_config.vocab_size, size=(4, 30))
res = bert_model(input_ids)
print(res)
if __name__ == "__main__":
run_bert()
LlamaModel
尤其在大模型学习中,嘎嘎提效👍
from transformers.models.llama import LlamaConfig, LlamaModel
import torch
def run_llama():
llamaConfig = LlamaConfig(
vocab_size=32000,
hidden_size=4096 // 2,
intermediate_size=11008 // 2,
num_hidden_layers=32 // 2,
num_attention_heads=32 // 2,
max_position_embeddings=2048 // 2,
)
llamamodel = LlamaModel(config=llamaConfig)
# 构建输入
input_ids = torch.randint(low=0, high=llamaConfig.vocab_size, size=(4, 30))
res = llamamodel(input_ids)
print(res)
if __name__ == "__main__":
run_llama()