事先声明:我只是个真·零基础小白,所以部分理解不到位或者有错误,望各位大佬不吝赐教!
第一部分 引入库部分
本代码采用的是苏老师写的bert4keras,即使用keras实现bert,包含层、模型、优化器、分词器等
bert4keras最好在tensorflow<=2.2以及keras<=2.3.1的条件下,即搭配python3.6食用
第二部分 加载BERT模型
BERT的结构是这样的(可以参考这篇文章:原来你是这样的BERT,i了i了!)
simBERTv2也是基于BERT的一个模型,所以需要先“搭建”一个BERT模型【此时我们只有最基本的bert配置】
我们先不看def/class的内容,直接快进到建立加载模型
然后我们去寻找build_transformer_model的源码(位于bert4keras\models.py),这个函数是负责建立一个模型(如果检查点存在,就会用load_weights_from_checkpoint加载检查点中存储的内容)
下面,我们来看看这些参数都是做什么的:
- config_path,checkpoint_path:变量如其名,即加载配置和检查点的路径
- model:即模型类型,bert4keras支持导入很多种模型,如果感兴趣的话可以前往bert4keras的models.py文件中的build_transformer_model()函数下查找
- application:从字面上看不出来是干什么的,我们返回*build_transformer_model()*中查看
application = application.lower()
if application in ['lm', 'unilm'] and model in ['electra', 't5']:
raise ValueError(
'"%s" model can not be used as "%s" application.\n' %
(model, application)
)
if application == 'lm':
MODEL = extend_with_language_model(MODEL)
elif application == 'unilm':
MODEL = extend_with_unified_language_model(MODEL)
还不懂,找到extend_with_language_model()函数查看,苏老师在这里写了备注:
lm给其他语言模型使用,ulm(unified language model)给seq2seq模型用
-
其他参数:实际上我们并没有在build_transformer_model()中找到对应的参数,但是kwargs这个参数就很耐人寻味了,我们在models.py中发现,这个参数在*apply()*中层的建立,encoder/decoder的构建中使用过
with_pool,with_mlm:可以完全参照字面意思,具体见这篇文章【NLP】bert4keras源码及矩阵计算解析
模型的剩余部分由这里建立
encoder = keras.models.Model(roformer.inputs, roformer.outputs[0])#bert中的encoder
seq2seq = keras.models.Model(roformer.inputs, roformer.outputs[1])#seq2seq用来计算损失
outputs = TotalLoss([2, 3])(roformer.inputs + roformer.outputs)#损失
model = keras.models.Model(roformer.inputs, outputs)#模型建立
第三部分 模型训练的准备
AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=1e-5, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()
model.complie用于配置训练的优化器、损失函数和准确率评测标准,详见tensorflow中model.compile()用法
第四部分 模型训练
train_generator = data_generator(corpus(), batch_size)
evaluator = Evaluate()
model.fit_generator(
train_generator.forfit(),
steps_per_epoch=steps_per_epoch,
epochs=epochs,
callbacks=[evaluator]
)
其中,data_generator用来生成数据,而我们虽然使用的语料是汉语文字,喂给真正BERT的还需要进行一些编码,即需要分词(Tokenizer);训练模型,喂给模型的不全是完整的句子,有时候需要遮住一些词喂给模型,而函数mask_encode起到的就是这个作用,苏老师在[博客](SimBERTv2来了!融合检索和生成的RoFormer-Sim模型 - 科学空间|Scientific Spaces)中的【生成】部分阐述了这个思想——即BART;
data_generator还涉及了数据蒸馏的部分,但是暂时我还学到这里,所以暂且搁置了
剩下的fit中则是一些比较常规的内容
训练语料corpus部分还是很容易懂的,有时间再写
第五部分 模型保存
训练好的模型会在每个epoch的最后保存下来,使用的方法为save_weights
def on_epoch_end(self, epoch, logs=None):
model.save_weights('./latest_model.weights')
# 保存最优
if logs['loss'] <= self.lowest:
self.lowest = logs['loss']
model.save_weights('./best_model.weights')
# 演示效果
just_show()
加载模型参数则采用load_weights
-----以上,stage1使用语料进行训练,并进行模型的保存,就结束了-----
然而,我们的models.py中,如果要加载检查点的权重,则要使用:
transformer.load_weights_from_checkpoint(checkpoint_path)
这样则会造成一些冲突,即你通过stage1保存的模型,是无法直接使用bert_transformer_model()的