天池NLP新闻文本分类学习赛心得-Task6
赛题链接:https://tianchi.aliyun.com/competition/entrance/531810/introduction
这一次task使用的深度模型为 BERT,模型强悍到不用太深层的了解NLP原理就能够有着良好的训练结果,对于这一次新闻文本分类学习赛亦是如此。
BERT原理:
1. 特征提取器
Transformer Encoder,特征提取器,由Nx个完全一样的layer组成,每个layer有2个sub-layer,分别是:Multi-Head Self-Attention机制、Position-Wise全连接前向神经网络。对于每个sub-layer,都添加了2个操作:残差连接Residual Connection和归一化Normalization,用公式来表示sub-layer的输出结果就是LayerNorm(x+Sublayer(x))。
Attention Mechanism:对于输入Input,有相应的向量query和key-value对,通过计算query和key关系的function,赋予每个value不同的权重,最终得到一个正确的向量输出Output。在Transformer编码器里,应用了两个Attention单元:Scaled Dot-Product Attention和Multi-Head Attention。
- Scaled Dot-Product Attention: Self-Attention机制是在该单元实现的。对于输入Input,通过线性变换得到Q、K、V,然后将Q和K通过Dot-Product相乘计算,得到输入Input中词与词之间的依赖关系,再通过尺度变换Scale、掩码Mask和Softmax操作,得到Self-Attention矩阵,最后跟V进行Dot-Product相乘计算。
- Multi-Head Attention: 通过h个不同线性变换,将d_model维的Q、K、V分别映射成d_k、d_k、d_v维,并行应用Self-Attention机制,得到h个d_v维的输出,进行拼接计算Concat、线性变换Linear操作。
2. 输入特征处理
BERT的输入是一个线性序列,支持单句文本和句对文本,句首用符号[CLS]表示,句尾用符号[SEP]表示,如果是句对,句子之间添加符号[SEP]。输入特征,由Token向量、Segment向量和Position向量三个共同组成,分别代表单词信息、句子信息、位置信息。
3. 预训练
BERT采用了MLM和NSP两种策略用于模型预训练。
- MLM: (Masked LM)对输入的单词序列,随机地掩盖15%的单词,然后对掩盖的单词做预测任务。相比传统标准条件语言模型只能left-to-right或right-to-left单向预测目标函数,MLM可以从任意方向预测被掩盖的单词。
- NSP: (Next Sentence Prediction)许多重要的下游任务譬如QA、NLI需要语言模型理解两个句子之间的关系,而传统的语言模型在训练的过程没有考虑句对关系的学习。
4. 任务微调
BERT提供了4种不同下游任务的微调方案:
(a)句对关系判断,第一个起始符号[CLS]经过Transformer编码器后,增加简单的Softmax层,即可用于分类;
(b)单句分类任务,具体实现同(a)一样;
(c)问答类任务,譬如SQuAD v1.1,问答系统输入文本序列的question和包含answer的段落,并在序列中标记answer,让BERT模型学习标记answer开始和结束的向量来训练模型;
(d)序列标准任务,譬如命名实体标注NER,识别系统输入标记好实体类别(人、组织、位置、其他无名实体)的文本序列进行微调训练,识别实体类别时,将序列的每个Token向量送到预测NER标签的分类层进行识别。
BERT模型部分代码:
# build model
class Model(nn.Module):
def __init__(self, vocab):
super(Model, self).__init__()
self.sent_rep_size = 256
self.doc_rep_size = sent_hidden_size * 2
self.all_parameters = {}
parameters = []
self.word_encoder = WordBertEncoder()
bert_parameters = self.word_encoder.get_bert_parameters()
self.sent_encoder = SentEncoder(self.sent_rep_size)
self.sent_attention = Attention(self.doc_rep_size)
parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_encoder.parameters())))
parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_attention.parameters())))
self.out = nn.Linear(self.doc_rep_size, vocab.label_size, bias=True)
parameters.extend(list(filter(lambda p: p.requires_grad, self.out.parameters())))
if use_cuda:
self.to(device)
if len(parameters) > 0:
self.all_parameters["basic_parameters"] = parameters
self.all_parameters["bert_parameters"] = bert_parameters
logging.info('Build model with bert word encoder, lstm sent encoder.')
para_num = sum([np.prod(list(p.size())) for p in self.parameters()])
logging.info('Model param num: %.2f M.' % (para_num / 1e6))
def forward(self, batch_inputs):
# batch_inputs(batch_inputs1, batch_inputs2): b x doc_len x sent_len
# batch_masks : b x doc_len x sent_len
batch_inputs1, batch_inputs2, batch_masks = batch_inputs
batch_size, max_doc_len, max_sent_len = batch_inputs1.shape[0], batch_inputs1.shape[1], batch_inputs1.shape[2]
batch_inputs1 = batch_inputs1.view(batch_size * max_doc_len, max_sent_len) # sen_num x sent_len
batch_inputs2 = batch_inputs2.view(batch_size * max_doc_len, max_sent_len) # sen_num x sent_len
batch_masks = batch_masks.view(batch_size * max_doc_len, max_sent_len) # sen_num x sent_len
sent_reps = self.word_encoder(batch_inputs1, batch_inputs2) # sen_num x sent_rep_size
sent_reps = sent_reps.view(batch_size, max_doc_len, self.sent_rep_size) # b x doc_len x sent_rep_size
batch_masks = batch_masks.view(batch_size, max_doc_len, max_sent_len) # b x doc_len x max_sent_len
sent_masks = batch_masks.bool().any(2).float() # b x doc_len
sent_hiddens = self.sent_encoder(sent_reps, sent_masks) # b x doc_len x doc_rep_size
doc_reps, atten_scores = self.sent_attention(sent_hiddens, sent_masks) # b x doc_rep_size
batch_outputs = self.out(doc_reps) # b x num_labels
return batch_outputs
model = Model(vocab)
一点心得:
在了解原理后再来构建模型,由于也是头一次参与到这种赛事,许多知都不懂,模型代码也是参考官方给的代码,但是将这些代码运行在我本地的破电脑上,显存不足(汗),没办法,不能享受BERT的魅力了,我也看到有一些小伙伴把代码跑出来了,羡慕呐。
NLP平安、人工智能平安、python平安。。。