bert4torch又双叒叕更新啦!新功能浅析~

一、背景

bert4torch是一款简洁的训练框架,经过半年的维护和使用已经越发完善,近期的工作主要是增加了很多实战示例,拿来就用还是很香了。不了解bert4torch可以通过前述两篇文章来浅尝一下~

bert4torch(参考bert4keras的pytorch实现)15 赞同 · 9 评论文章

bert4torch快速上手16 赞同 · 3 评论文章

二、主要功能复述

三、新增功能简介

3.1 更新日志

  • 2022年8月28更新:增加nl2sql示例, 增加自定义metrics,支持断点续训
  • 2022年8月21更新:增加W2NER和DiffCSE示例,打印Epoch开始的时间戳,增加parallel_apply, 兼容torch<=1.7.1的torch.div无rounding_mode
  • 2022年8月14更新:增加有监督句向量、关系抽取、文本生成实验指标,兼容torch<1.9.0的缺失take_along_dim,修复bart中位置向量514的问题,修复Sptokenizer对符号不转换
  • 2022年7月27更新:增加mixup/manifold_mixup/temporal_ensembling策略, 修复pgd策略param.grad为空的问题,修改tokenizer支持批量,增加uie示例
  • 2022年7月16更新:修复原来CRF训练中loss陡增的问题,修复xlnet的token_type_ids输入显存占用大的问题
  • 2022年7月10更新:增加金融中文FAQ示例,天池新闻分类top1案例,增加EarlyStop,CRF中自带转bool类型
  • 2022年6月29更新:增加ner的实验,测试crf不同初始化的效果,bert-whitening中文实验
  • 2022年6月13更新:增加seq2seq+前缀树,增加SimCSE/ESimCSE/PromptBert等无监督语义相似度的中文实验
  • 2022年6月05更新:增加PromptBert、PET、P-tuning示例,修改tokenizer对special_tokens分词错误的问题,增加t5_pegasus
  • 2022年5月29更新:transformer_xl、xlnet模型, 修改sinusoid位置向量被init_weight的bug, EMA,sohu情感分类示例
  • 2022年5月17更新:增加预训练代码,支持增加embedding输入(如词性,word粒度embedding)

3.2 主要新增功能

  • 新增预训练模型

新增了xlnet和t5_pegasus两个预训练模型

  • 支持增加额外embedding输入

比如在做实体提取时候,想尝试下加入token的词性来看看是否能提升模型效果,这个时候就需要增加额外的embedding,使用时候直接传入layer_add_embs参数即可

build_transformer_model(
    config_path=config_path, # 模型的config文件地址
    checkpoint_path=checkpoint_path, # 模型文件地址,默认值None表示不加载预训练模型
    model='bert', # 加载的模型结构,这里Model也可以基于nn.Module自定义后传入
    application='encoder',  # 模型应用,支持encoder,lm和unilm格式
    segment_vocab_size=2,  # type_token_ids数量,默认为2,如不传入segment_ids则需设置为0
    with_pool=False,  # 是否包含Pool部分
    with_nsp=False,  # 是否包含NSP部分
    with_mlm=False,  # 是否包含MLM部分
    return_model_config=False,  # 是否返回模型配置参数
    output_all_encoded_layers=False,  # 是否返回所有hidden_state层
    layer_add_embs=nn.Embedding(2, 768),  # 自定义额外的embedding输入
)
  • 增加自定义metrics

在训练过程中想打印一些指标来观测训练集上的指标(默认会打印loss),很多时候这些指标还需要自定义,参考keras的实现,目前bert4torch也支持了,使用方式如下

'''
定义使用的loss、optimizer和metrics,这里支持自定义
'''
def eval(y_pred, y_true):
    # 仅做示意
    return {'rouge-1': random.random(), 'rouge-2': random.random(), 'rouge-l': random.random(), 'bleu': random.random()}

def f1(y_pred, y_true):
    # 仅做示意
    return random.random()

model.compile(
    loss=nn.CrossEntropyLoss(), # 可以自定义Loss
    optimizer=optim.Adam(model.parameters(), lr=2e-5),  # 可以自定义优化器
    scheduler=None, # 可以自定义scheduler
    adversarial_train={'name': 'fgm'},  # 训练trick方案设置,支持fgm, pgd, gradient_penalty, vat
    metrics=['accuracy', eval, {'f1': f1}]  # loss等默认打印的字段无需设置,可多种方式自定义回调函数
)
  • 断点续训

很多时候由于各种原因(显存不足,意外断电)等情况,虽然你保存的最优模型,但是你的优化器没保存,导致无法接着训练,bert4torch内置了简单的小函数,即可方便使用断点续训,训练进度条也会从上次断掉的地方重新开始记录

# =======断点续训========
# 在Callback中的on_epoch_end()或on_batch_end()保存需要的参数
model.save_weights(save_path, prefix=None)  # 保存模型权重
model.save_steps_params(save_path)  # 保存训练进度参数,当前的epoch和step,断点续训使用
torch.save(optimizer.state_dict(), save_path)  # 保存优化器,断点续训使用

# 加载前序训练保存的参数
model.load_weights(save_path)  # 加载模型权重
model.load_steps_params(save_path)  # 加载训练进度参数,断点续训使用
state_dict = torch.load(save_path, map_location='cpu')  # 加载优化器,断点续训使用
optimizer.load_state_dict(state_dict)
  • 额外小细节

1. 每个Epoch会打印时间戳,方便查看训练的起止时间(仅仅记录训练时长总是需要换算)

打印Epoch同时记录时间戳

2. 句向量的获取简单配置即可get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None)

def get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None):
    ''' 获取句向量
    '''
    if pool_strategy == 'pooler':
        return pooler
    elif pool_strategy == 'cls':
        if isinstance(hidden_state, (list, tuple)):
            hidden_state = hidden_state[-1]
        assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} strategy request tensor hidden_state'
        return hidden_state[:, 0]
    elif pool_strategy in {'last-avg', 'mean'}:
        if isinstance(hidden_state, (list, tuple)):
            hidden_state = hidden_state[-1]
        assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy request tensor hidden_state'
        hid = torch.sum(hidden_state * attention_mask[:, :, None], dim=1)
        attention_mask = torch.sum(attention_mask, dim=1)[:, None]
        return hid / attention_mask
    elif pool_strategy in {'last-max', 'max'}:
        if isinstance(hidden_state, (list, tuple)):
            hidden_state = hidden_state[-1]
        assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy request tensor hidden_state'
        hid = hidden_state * attention_mask[:, :, None]
        return torch.max(hid, dim=1)
    elif pool_strategy == 'first-last-avg':
        assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy request list hidden_state'
        hid = torch.sum(hidden_state[1] * attention_mask[:, :, None], dim=1) # 这里不取0
        hid += torch.sum(hidden_state[-1] * attention_mask[:, :, None], dim=1)
        attention_mask = torch.sum(attention_mask, dim=1)[:, None]
        return hid / (2 * attention_mask)
    elif pool_strategy == 'custom':
        # 取指定层
        assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy request list hidden_state'
        assert isinstance(custom_layer, (int, list, tuple)), f'{pool_strategy} pooling strategy request int/list/tuple custom_layer'
        custom_layer = [custom_layer] if isinstance(custom_layer, int) else custom_layer
        hid = 0
        for i, layer in enumerate(custom_layer, start=1):
            hid += torch.sum(hidden_state[layer] * attention_mask[:, :, None], dim=1)
        attention_mask = torch.sum(attention_mask, dim=1)[:, None]
        return hid / (i * attention_mask)
    else:
        raise ValueError('pool_strategy illegal')

3. 全局seed,固定随机种子一般会写几行简单的代码,但是太常用了,参考pytorch_lightning使用seed_everything(seed)来固定随机数

def seed_everything(seed=None):
    '''固定seed
    '''
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    if (seed is None) or not (min_seed_value <= seed <= max_seed_value):
        random.randint(np.iinfo(np.uint32).min, np.iinfo(np.uint32).max)
    print(f"Global seed set to {seed}")
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值