# BERT: build the model from a local config file and count its trainable parameters
from transformers import BertModel, BertConfig
# Read the architecture hyper-parameters from the local JSON file and build
# the BERT model from that config object (no `from_pretrained` call is made
# here), keeping the pooling layer enabled.
config = BertConfig.from_json_file('bert-base/config.json')
bert_model = BertModel(config, add_pooling_layer=True)

# Sum the element counts of every parameter tensor whose `requires_grad`
# flag is set; tensors with the flag cleared are left out of the total.
pytorch_total_params = sum(
    weight.numel()
    for weight in filter(
        lambda tensor: tensor.requires_grad,
        bert_model.parameters(),
    )
)

# The label is a runtime string ("model parameter count") and is emitted as-is.
print('模型参数量: ', pytorch_total_params)
# GPT-2: build the model from a local config file and count its trainable parameters
from transformers import GPT2Config, GPT2Model
# Read the architecture hyper-parameters from the local JSON file and build
# the GPT-2 model from that config object (no `from_pretrained` call is made
# here). Note that `config` is rebound to the GPT-2 configuration.
config = GPT2Config.from_json_file('gpt2-config.json')
model = GPT2Model(config)

# Sum the element counts of every parameter tensor whose `requires_grad`
# flag is set; tensors with the flag cleared are left out of the total.
pytorch_total_params = sum(
    weight.numel()
    for weight in filter(
        lambda tensor: tensor.requires_grad,
        model.parameters(),
    )
)

# The label is a runtime string ("model parameter count") and is emitted as-is.
print('模型参数量: ', pytorch_total_params)
# gpt