【Python】科研代码学习:四 AutoClass,Logging

本文详细阐述了在HuggingFaceTransformers中使用AutoClass自动加载模型、AutoConfig配置模型参数和Logging进行日志记录的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

AutoClass

  • .from_pretrained() 以及中间提供的模型名参数,比如 model = AutoModel.from_pretrained("google-bert/bert-base-cased")AutoModel 就可以自动匹配最合适的内容,比如这里就是 BertModel
  • AutoClass 里面的自动类特别多,主要分为
    1)通用自动类
    在这里插入图片描述
    2)具体任务对应的自动类
    太多了,截图了部分
    在这里插入图片描述

AutoConfig

  • 最重要的参数就是如图第一个参数
    可以提供模型名,或模型路径(到save_pretrained()路径,或configuration.json)
    在这里插入图片描述

The configuration class to instantiate is selected based on the model_type property of the config object that is loaded, or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path

  • 也就是说,自动加载到哪个类,首先取决于 model_type,若它缺失,再根据提供的上述参数进行模式匹配
    下面列了几个比较常用的模型映射
bark — BarkConfig (Bark model)
bart — BartConfig (BART model)
bert — BertConfig (BERT model)
bloom — BloomConfig (BLOOM model)
falcon — FalconConfig (Falcon model)
gpt2 — GPT2Config (OpenAI GPT-2 model)
llama — LlamaConfig (LLaMA model)
t5 — T5Config (T5 model)
... 具体看官方API文档
  • 例子
  • AutoConfig.from_pretrained() 来获得对应的Config
from transformers import AutoConfig

# Download configuration from huggingface.co and cache.
config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")

# Download configuration from huggingface.co (user-uploaded) and cache.
config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")

# If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
config = AutoConfig.from_pretrained("./test/bert_saved_model/")

# Load a specific configuration file.
config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")

# Change some config attributes when loading a pretrained config.
config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
config.output_attentions

config, unused_kwargs = AutoConfig.from_pretrained(
    "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
config.output_attentions

unused_kwargs

AutoTokenizer

  • AutoConfig 同理,不赘述了
    在这里插入图片描述
  • AutoTokenizer.from_pretrained() 来获得对应的Tokenizer
from transformers import AutoTokenizer

# Download vocabulary from huggingface.co and cache.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

# Download vocabulary from huggingface.co (user-uploaded) and cache.
tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")

# If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
# tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")

# Download vocabulary from huggingface.co and define model-specific arguments
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)

AutoModel

  • 加载 AutoModel 有如下的方式
  • from_config
    注意:from_config 只加载模型的配置configuration,不加载模型的权重 weights!
from transformers import AutoConfig, AutoModel

# Download configuration from huggingface.co and cache.
config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
model = AutoModel.from_config(config)
  • from_pretrained
    注意:from_pretrained 加载模型的权重 weights
from transformers import AutoConfig, AutoModel

# Download model and configuration from huggingface.co and cache.
model = AutoModel.from_pretrained("google-bert/bert-base-cased")

# Update configuration during loading
model = AutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)

# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
model = AutoModel.from_pretrained(
    "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
)
  • 别忘了加载后,添加 model.eval() 或者 model.train()
  • 当然还有对应的 TFAutoModelFlaxAutoModel,不赘述了

AutoModelForPreTraining

  • 加载用于预训练 PreTraining 的模型
    也就是带有预训练头(PreTraining Head)
  • 同理,通过 from_config / from_pretrained 加载
    用于预训练的模型参数比上述 AutoModel 的参数较少,请参考官方文档
  • 同理,也有对应的 TFAutoModelForPreTraining / FlaxAutoModelForPreTraining

AutoModelForCausalLM

  • 接下来看一下 NLP 领域中的一些AutoClass
    AutoModelForCausalLM 加载的模型,带有因果语言模型头 (with a causal language modeling head)
    它就可以专注做NLP中自回归的语言任务
  • 同理,通过 from_config / from_pretrained 加载
  • 同理,也有对应的 TFAutoModelForCausalLM / FlaxAutoModelForCausalLM

Causal Language Model

  • 注意不是 Casual 随意的,是 Causal 因果的
    文本生成系列之因果语言模型
    与MLM掩码语言模型相反
    因果语言模型采用了对角掩蔽矩阵,使得每个token只能看到在它之前的token信息,而看不到在它之后的token,模型的训练目标是根据在这之前的token来预测下一个位置的token。

AutoModelForMaskedLM

  • 那如果不是自回归任务,而是掩码语言任务,那就用 AutoModelForMaskedLM / TFAutoModelForMaskedLM / FlaxAutoModelForMaskedLM

AutoModelForSeq2SeqLM

  • 加载此模型,会有一个s2s语言模型头 (with a sequence-to-sequence language modeling head)
    主要是给 seq2seq模型用的
  • 同理,通过 from_config / from_pretrained 加载
  • 同理,也有对应的 TFAutoModelForSeq2SeqLM / FlaxAutoModelForSeq2SeqLM

其他一些具体任务的AutoClass

  • AutoModelForSequenceClassification,做序列分类用的 (with a sequence classification head)
  • AutoModelForMultipleChoice,做多选用的 (with a multiple choice head)
  • AutoModelForNextSentencePrediction,做NSP用的 (with a next sentence prediction head)
  • AutoModelForTokenClassification,做Token分类用的 (with a token classification head)
  • AutoModelForQuestionAnswering,做QA用的 (with a question answering head)
  • CV等其他方向不赘述了

Logging

  • HF官网API:Logging
    Logging 主要是方面输出日志记录用的
    python 也有logging库,比如 import logging 可以直接导入
    HF 也提供了logging库,导入为 from transformers.utils import logging,看一下例子
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")
logger.critic("Here is a CRITIC Report")
logger.error("反正里面是日志信息,自己看着填")
logger.warning("WARN")
logger.info("INFO")
logger.debug("DEBUG")

在这里插入图片描述

  • 看源代码,它确实导入了 import logging ,但也有与 huggingface_hub.utilstqdm 做了一些交互。
  • 使用不同程度的日志类型记录,会更方面日志管理,比直接 print("Debug") 更规范
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值