【Python】Research Code Learning, Part 4: AutoClass, Logging
AutoClass
- Via .from_pretrained() and the model name passed to it, e.g. model = AutoModel.from_pretrained("google-bert/bert-base-cased"), AutoModel automatically resolves to the most suitable concrete class, which in this case is BertModel (see the sketch after this list).
- There are a great many auto classes. They fall mainly into two groups:
1) generic auto classes
2) task-specific auto classes
Far too many to list in full; only a subset is covered below.
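A minimal sketch of the auto-resolution just described, using the same checkpoint as above:
from transformers import AutoModel, BertModel

# AutoModel inspects the checkpoint's config and picks the matching concrete class.
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
print(type(model).__name__)          # BertModel
print(isinstance(model, BertModel))  # True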
AutoConfig
- The most important argument is the first one, pretrained_model_name_or_path: you can pass a model name, or a model path (a directory written by save_pretrained(), or a configuration .json file).
The configuration class to instantiate is selected based on the model_type property of the config object that is loaded, or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path.
- In other words, which class is auto-loaded depends first on model_type; if that is missing, pattern matching on the name/path argument above is used instead.
A few of the more common model mappings are listed below:
bark — BarkConfig (Bark model)
bart — BartConfig (BART model)
bert — BertConfig (BERT model)
bloom — BloomConfig (BLOOM model)
falcon — FalconConfig (Falcon model)
gpt2 — GPT2Config (OpenAI GPT-2 model)
llama — LlamaConfig (LLaMA model)
t5 — T5Config (T5 model)
... see the official API docs for the complete mapping
- Example: use AutoConfig.from_pretrained() to obtain the corresponding Config:
from transformers import AutoConfig
# Download configuration from huggingface.co and cache.
config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
# Download configuration from huggingface.co (user-uploaded) and cache.
config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
# If configuration file is in a directory (e.g., was saved using save_pretrained('./test/saved_model/')).
config = AutoConfig.from_pretrained("./test/bert_saved_model/")
# Load a specific configuration file.
config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
# Change some config attributes when loading a pretrained config.
config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
config.output_attentions  # True
config, unused_kwargs = AutoConfig.from_pretrained(
"google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
config.output_attentions  # True
unused_kwargs  # {'foo': False}
AutoTokenizer
- Same idea as AutoConfig, so we won't repeat the details: use AutoTokenizer.from_pretrained() to obtain the corresponding Tokenizer (a short usage sketch follows the example below).
from transformers import AutoTokenizer
# Download vocabulary from huggingface.co and cache.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# Download vocabulary from huggingface.co (user-uploaded) and cache.
tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
# If vocabulary files are in a directory (e.g. tokenizer was saved using save_pretrained('./test/saved_model/'))
# tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
# Download vocabulary from huggingface.co and define model-specific arguments
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
AutoModel
- There are two ways to load an AutoModel. The first is from_config.
Note: from_config only builds the model from its configuration; it does not load the pretrained weights!
from transformers import AutoConfig, AutoModel
# Download configuration from huggingface.co and cache.
config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
model = AutoModel.from_config(config)
The second is from_pretrained.
Note: from_pretrained does load the pretrained weights.
from transformers import AutoConfig, AutoModel
# Download model and configuration from huggingface.co and cache.
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
# Update configuration during loading
model = AutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
model = AutoModel.from_pretrained(
"./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
)
- After loading, don't forget to call model.eval() (for inference) or model.train() (for training), as shown in the sketch after this list.
- There are of course also TFAutoModel and FlaxAutoModel counterparts; we won't go into them here.
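Putting eval() into context, a minimal inference sketch (assuming PyTorch):
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
model.eval()  # disable dropout etc. for inference

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, seq_len, 768]) for BERT-base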
AutoModelForPreTraining
- Loads a model for pretraining, i.e. one with a pretraining head.
- As before, load it via from_config / from_pretrained. These pretraining classes take somewhat fewer arguments than AutoModel above; see the official docs. A small sketch follows.
- Likewise, there are TFAutoModelForPreTraining / FlaxAutoModelForPreTraining counterparts.
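For instance (a small sketch; for BERT the pretraining head combines MLM and NSP):
from transformers import AutoModelForPreTraining

model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
print(type(model).__name__)  # BertForPreTraining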
AutoModelForCausalLM
- Next, some of the NLP-oriented AutoClasses. A model loaded with AutoModelForCausalLM comes with a causal language modeling head, making it suited to autoregressive language tasks in NLP.
- As before, load via from_config / from_pretrained; likewise, there are TFAutoModelForCausalLM / FlaxAutoModelForCausalLM counterparts.
Causal Language Model
- Note: Causal as in cause-and-effect, not Casual.
See also the (Chinese) post 文本生成系列之因果语言模型, a text-generation series entry on causal language models.
In contrast to MLM (masked language models), a causal language model applies a lower-triangular attention mask, so each token can only see the tokens before it and none after it; the training objective is to predict the token at the next position from the preceding tokens.
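A small sketch of both points, assuming PyTorch and the gpt2 checkpoint (an illustrative choice, not from the original post): the lower-triangular mask, and generation with a causal LM:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The causal attention mask is lower-triangular: row i may attend to columns <= i.
print(torch.tril(torch.ones(4, 4)))

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("The meaning of life is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))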
AutoModelForMaskedLM
- If the task is masked language modeling rather than autoregressive generation, use AutoModelForMaskedLM / TFAutoModelForMaskedLM / FlaxAutoModelForMaskedLM instead; a fill-mask sketch follows.
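A fill-mask sketch with AutoModelForMaskedLM (PyTorch assumed; checkpoint as in the earlier examples):
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased")
model.eval()

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
# Find the [MASK] position and take the highest-scoring token there.
mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))  # typically "paris"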
AutoModelForSeq2SeqLM
- Loading this gives a model with a sequence-to-sequence language modeling head, intended mainly for seq2seq models.
- As before, load via from_config / from_pretrained; likewise, there are TFAutoModelForSeq2SeqLM / FlaxAutoModelForSeq2SeqLM counterparts. A short sketch follows.
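A short sketch using the t5-small checkpoint (an illustrative assumption; any seq2seq checkpoint works):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))  # e.g. "Hallo, wie geht es Ihnen?"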
Other task-specific AutoClasses
AutoModelForSequenceClassification, for sequence classification (with a sequence classification head)
AutoModelForMultipleChoice, for multiple choice (with a multiple choice head)
AutoModelForNextSentencePrediction, for NSP (with a next sentence prediction head)
AutoModelForTokenClassification, for token classification (with a token classification head)
AutoModelForQuestionAnswering, for QA (with a question answering head)
- CV and other domains are not covered here.
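As one example from the list above, a sequence classification sketch (the SST-2 fine-tuned checkpoint below is an illustrative assumption):
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)
model.eval()

inputs = tokenizer("I love this movie!", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(dim=-1).item()])  # "POSITIVE"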
Logging
- HF official API: Logging
Logging is mainly there to make writing log output convenient.
Python itself has a logging library (import logging) that you can use directly.
HF also provides its own logging module, imported as from transformers.utils import logging. An example:
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
logger.critical("Here is a CRITICAL report")  # note: the method is critical(), not critic()
logger.error("any log message you like goes here")
logger.warning("WARN")
logger.info("INFO")
logger.debug("DEBUG")  # not displayed: verbosity was set to INFO, which is above DEBUG
- Looking at the source code, it does import the standard logging module, but it also interacts with huggingface_hub.utils and tqdm.
- Using distinct log levels makes managing logs much easier, and it is more disciplined than a bare print("Debug").
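Beyond set_verbosity_info(), a few more knobs worth knowing, all from transformers.utils.logging:
from transformers.utils import logging

logging.set_verbosity_error()    # only errors
logging.set_verbosity_warning()  # warnings and above
logging.set_verbosity_debug()    # everything
logging.disable_progress_bar()   # also silence the tqdm download bars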