BertEmbedding的各种用法


向AI转型的程序员都关注了这个号????????????

机器学习AI算法工程   公众号:datayx


bert自从在 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 中被提出后,因其性能卓越受到了极大的关注,在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于 中文Bert预训练 。

为了方便大家的使用,fastNLP提供了预训练的Embedding权重及数据集的自动下载,支持自动下载的Embedding和数据集见 数据集 。或您可从 使用Embedding模块将文本转成向量 与 使用Loader和Pipe加载并处理数据集 了解更多相关信息。

中文任务

下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。

注解

本教程必须使用 GPU 进行实验,并且会花费大量的时间

1. 使用Bert进行文本分类

文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类

1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!

这里我们使用fastNLP提供自动下载的微博分类进行测试

from fastNLP.io import WeiboSenti100kPipe
from fastNLP.embeddings import BertEmbedding
from fastNLP.models import BertForSequenceClassification
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
import torch

data_bundle =WeiboSenti100kPipe().process_from_file()
data_bundle.rename_field('chars', 'words')

# 载入BertEmbedding
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)

# 载入模型
model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))

# 训练模型
device = 0 if torch.cuda.is_available() else 'cpu'
trainer = Trainer(data_bundle.get_dataset('train'), model,
                  optimizer=Adam(model_params=model.parameters(), lr=2e-5),
                  loss=CrossEntropyLoss(), device=device,
                  batch_size=8, dev_data=data_bundle.get_dataset('dev'),
                  metrics=AccuracyMetric(), n_epochs=2, print_every=1)
trainer.train()

# 测试结果
from fastNLP import Tester

tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())
tester.test()

输出结果:

In Epoch:1/Step:12499, got best dev performance:
AccuracyMetric: acc=0.9838
Reloaded the best model.
Evaluate data in 63.84 seconds!
[tester]
AccuracyMetric: acc=0.9815

2. 使用Bert进行命名实体识别

命名实体识别是给定一句话,标记出其中的实体。一般序列标注的任务都使用conll格式,conll格式是至一行中通过制表符分隔不同的内容,使用空行分隔 两句话,例如下面的例子

中   B-ORG
共   I-ORG
中   I-ORG
央   I-ORG
致   O
中   B-ORG
国   I-ORG
致   I-ORG
公   I-ORG
党   I-ORG
十   I-ORG
一   I-ORG
大   I-ORG
的   O
贺   O
词   O

这部分内容请参考 序列标注

3. 使用Bert进行文本匹配

文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否 具有相同的意思。这里我们使用

from fastNLP.io import CNXNLIBertPipe
from fastNLP.embeddings import BertEmbedding
from fastNLP.models import BertForSentenceMatching
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
from fastNLP.core.optimizer import AdamW
from fastNLP.core.callback import WarmupCallback
from fastNLP import Tester
import torch

data_bundle = CNXNLIBertPipe().process_from_file()
data_bundle.rename_field('chars', 'words')
print(data_bundle)

# 载入BertEmbedding
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)

# 载入模型
model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))

# 训练模型
callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]
device = 0 if torch.cuda.is_available() else 'cpu'
trainer = Trainer(data_bundle.get_dataset('train'), model,
                  optimizer=AdamW(params=model.parameters(), lr=4e-5),
                  loss=CrossEntropyLoss(), device=device,
                  batch_size=8, dev_data=data_bundle.get_dataset('dev'),
                  metrics=AccuracyMetric(), n_epochs=5, print_every=1,
                  update_every=8, callbacks=callbacks)
trainer.train()

tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())
tester.test()

运行结果:

In Epoch:3/Step:73632, got best dev performance:
AccuracyMetric: acc=0.781928
Reloaded the best model.
Evaluate data in 18.54 seconds!
[tester]
AccuracyMetric: acc=0.783633

4. 使用Bert进行中文问答

问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。例如:

"context": "锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常
用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及
作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合
相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单
皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大
钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师
傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼
和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:",
"question": "锣鼓经是什么?",
"answers": [
    {
      "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
      "answer_start": 4
    },
    {
      "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
      "answer_start": 4
    },
    {
      "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
      "answer_start": 4
    }
]

您可以通过以下的代码训练 (原文代码:CMRC2018)

from fastNLP.embeddings import BertEmbedding
from fastNLP.models import BertForQuestionAnswering
from fastNLP.core.losses import CMRC2018Loss
from fastNLP.core.metrics import CMRC2018Metric
from fastNLP.io.pipe.qa import CMRC2018BertPipe
from fastNLP import Trainer, BucketSampler
from fastNLP import WarmupCallback, GradientClipCallback
from fastNLP.core.optimizer import AdamW
import torch

data_bundle = CMRC2018BertPipe().process_from_file()
data_bundle.rename_field('chars', 'words')

print(data_bundle)

embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,
                      dropout=0.5, word_dropout=0.01)
model = BertForQuestionAnswering(embed)
loss = CMRC2018Loss()
metric = CMRC2018Metric()

wm_callback = WarmupCallback(schedule='linear')
gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')
callbacks = [wm_callback, gc_callback]

optimizer = AdamW(model.parameters(), lr=5e-5)

device = 0 if torch.cuda.is_available() else 'cpu'
trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
                  sampler=BucketSampler(seq_len_field_name='context_len'),
                  dev_data=data_bundle.get_dataset('dev'), metrics=metric,
                  callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,
                  test_use_tqdm=False, update_every=10)
trainer.train(load_best_model=False)

训练结果(和原论文中报道的基本一致):

In Epoch:2/Step:1692, got best dev performance:
CMRC2018Metric: f1=85.61, em=66.08

阅读过本文的人还看了以下文章:

【全套视频课】最全的目标检测算法系列讲解,通俗易懂!

《美团机器学习实践》_美团算法团队.pdf

《深度学习入门:基于Python的理论与实现》高清中文PDF+源码

python就业班学习视频,从入门到实战项目

2019最新《PyTorch自然语言处理》英、中文版PDF+源码

《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码

《深度学习之pytorch》pdf+附书源码

PyTorch深度学习快速实战入门《pytorch-handbook》

【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》

《Python数据分析与挖掘实战》PDF+完整源码

汽车行业完整知识图谱项目实战视频(全23课)

李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材

笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!

《神经网络与深度学习》最新2018版中英PDF+源码

将机器学习模型部署为REST API

FashionAI服装属性标签图像识别Top1-5方案分享

重要开源!CNN-RNN-CTC 实现手写汉字识别

yolo3 检测出图像中的不规则汉字

同样是机器学习算法工程师,你的面试为什么过不了?

前海征信大数据算法:风险概率预测

【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类

VGG16迁移学习,实现医学图像识别分类工程项目

特征工程(一)

特征工程(二) :文本数据的展开、过滤和分块

特征工程(三):特征缩放,从词袋到 TF-IDF

特征工程(四): 类别特征

特征工程(五): PCA 降维

特征工程(六): 非线性特征提取和模型堆叠

特征工程(七):图像特征提取和深度学习

如何利用全新的决策树集成级联结构gcForest做特征工程并打分?

Machine Learning Yearning 中文翻译稿

蚂蚁金服2018秋招-算法工程师(共四面)通过

全球AI挑战-场景分类的比赛源码(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)

python+flask搭建CNN在线识别手写中文网站

中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程

不断更新资源

深度学习、机器学习、数据分析、python

 搜索公众号添加: datayx  

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值