背景
本人经常会阅读苏神的科学空间网站,里面有很多对前言paper浅显易懂的解释,以及很多苏神自己的创新实践;并且基于bert4keras框架都有了相应的代码实现。但是由于本人主要用pytorch开发,因此参考bert4keras开发了bert4torch项目,实现了bert4keras的主要功能。
简介
bert4torch是一个基于pytorch的训练框架,前期以效仿和实现bert4keras的主要功能为主,方便加载多类预训练模型进行finetune,提供了中文注释方便用户理解模型结构。主要是期望应对新项目时,可以直接调用不同的预训练模型直接finetune,或方便用户基于bert进行魔改,快速验证自己的idea;节省在github上clone各种项目耗时耗力,且本地文件各种copy的问题。
- pip安装
pip install bert4torch
主要功能
1、加载预训练权重(bert、roberta、albert、nezha、bart、RoFormer、ELECTRA、GPT、GPT2、T5)继续进行finetune
目前支持的预训练模型一览
2、在bert基础上灵活定义自己模型:主要是可以接在bert的[btz, seq_len, hdsz]的隐含层向量后做各种魔改
3、调用方式和bert4keras基本一致,简洁高效
model.fit(
train_dataloader,
steps_per_epoch=1000,
epochs=epochs,
callbacks=[evaluator]
)
4、实现基于keras的训练进度条动态展示
仿照keras的模型训练进度条
5、配合torchinfo,实现打印各层参数量功能
打印参数
6、结合logger,或者tensorboard可以在后台打印日志
支持在训练开始/结束,batch开始/结束,epoch的开始/结束,记录日志,写tensorboard等
class Callback(object):
'''Callback基类
'''
def __init__(self):
pass
def on_train_begin(self, logs=None):
pass
def on_train_end(self, logs=None):
pass
def on_epoch_begin(self, global_step, epoch, logs=None):
pass
def on_epoch_end(self, global_step, epoch, logs=None):
pass
def on_batch_begin(self, global_step, batch, logs=None):
pass
def on_batch_end(self, global_step, batch, logs=None):
pass
7、集成多个example,可以作为自己的训练框架,方便在同一个数据集上尝试多种解决方案
实现多个example可供参考
未来计划
- Transformer-XL、XLnet等其他网络架构
- 前沿的各类模型idea实现,如苏神科学空间网站的诸多idea