读源码之SimBertv2-stage1

事先声明:我只是个真·零基础小白,所以部分理解不到位或者有错误,望各位大佬不吝赐教!

第一部分 引入库部分

在这里插入图片描述

本代码采用的是苏老师写的bert4keras,即使用keras实现bert,包含层、模型、优化器、分词器等

bert4keras最好在tensorflow<=2.2以及keras<=2.3.1的条件下,即搭配python3.6食用

第二部分 加载BERT模型

BERT的结构是这样的(可以参考这篇文章:原来你是这样的BERT,i了i了!

simBERTv2也是基于BERT的一个模型,所以需要先“搭建”一个BERT模型【此时我们只有最基本的bert配置】

我们先不看def/class的内容,直接快进到建立加载模型

在这里插入图片描述

然后我们去寻找build_transformer_model的源码(位于bert4keras\models.py),这个函数是负责建立一个模型(如果检查点存在,就会用load_weights_from_checkpoint加载检查点中存储的内容)

下面,我们来看看这些参数都是做什么的:

  1. config_path,checkpoint_path:变量如其名,即加载配置和检查点的路径
  2. model:即模型类型,bert4keras支持导入很多种模型,如果感兴趣的话可以前往bert4keras的models.py文件中的build_transformer_model()函数下查找
  3. application:从字面上看不出来是干什么的,我们返回*build_transformer_model()*中查看
application = application.lower()
if application in ['lm', 'unilm'] and model in ['electra', 't5']:
    raise ValueError(
        '"%s" model can not be used as "%s" application.\n' %
        (model, application)
    )

if application == 'lm':
    MODEL = extend_with_language_model(MODEL)
elif application == 'unilm':
    MODEL = extend_with_unified_language_model(MODEL)

还不懂,找到extend_with_language_model()函数查看,苏老师在这里写了备注:

lm给其他语言模型使用,ulm(unified language model)给seq2seq模型用

  1. 其他参数:实际上我们并没有在build_transformer_model()中找到对应的参数,但是kwargs这个参数就很耐人寻味了,我们在models.py中发现,这个参数在*apply()*中层的建立,encoder/decoder的构建中使用过

    with_pool,with_mlm:可以完全参照字面意思,具体见这篇文章【NLP】bert4keras源码及矩阵计算解析

模型的剩余部分由这里建立

encoder = keras.models.Model(roformer.inputs, roformer.outputs[0])#bert中的encoder
seq2seq = keras.models.Model(roformer.inputs, roformer.outputs[1])#seq2seq用来计算损失

outputs = TotalLoss([2, 3])(roformer.inputs + roformer.outputs)#损失
model = keras.models.Model(roformer.inputs, outputs)#模型建立

第三部分 模型训练的准备

AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=1e-5, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()

model.complie用于配置训练的优化器、损失函数和准确率评测标准,详见tensorflow中model.compile()用法

第四部分 模型训练

train_generator = data_generator(corpus(), batch_size)
evaluator = Evaluate()

model.fit_generator(
    train_generator.forfit(),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    callbacks=[evaluator]
)

其中,data_generator用来生成数据,而我们虽然使用的语料是汉语文字,喂给真正BERT的还需要进行一些编码,即需要分词(Tokenizer);训练模型,喂给模型的不全是完整的句子,有时候需要遮住一些词喂给模型,而函数mask_encode起到的就是这个作用,苏老师在[博客](SimBERTv2来了!融合检索和生成的RoFormer-Sim模型 - 科学空间|Scientific Spaces)中的【生成】部分阐述了这个思想——即BART;

data_generator还涉及了数据蒸馏的部分,但是暂时我还学到这里,所以暂且搁置了

剩下的fit中则是一些比较常规的内容

训练语料corpus部分还是很容易懂的,有时间再写

第五部分 模型保存

训练好的模型会在每个epoch的最后保存下来,使用的方法为save_weights

def on_epoch_end(self, epoch, logs=None):
    model.save_weights('./latest_model.weights')
    # 保存最优
    if logs['loss'] <= self.lowest:
        self.lowest = logs['loss']
        model.save_weights('./best_model.weights')
        # 演示效果
        just_show()

加载模型参数则采用load_weights

-----以上,stage1使用语料进行训练,并进行模型的保存,就结束了-----

然而,我们的models.py中,如果要加载检查点的权重,则要使用:

transformer.load_weights_from_checkpoint(checkpoint_path)

这样则会造成一些冲突,即你通过stage1保存的模型,是无法直接使用bert_transformer_model()的

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值