BERT 的 Embedding
- BERT 和 Transformer 在 Embedding 的时候有两个区别:
- 由于 BERT 采用了两个句子拼接后作为一个 sample 的方法,我们需要在输入的时候嵌入当前的词属于第一句话还是第二句话这个信息,因此我们需要 segment embedding
- BERT 也需要编码位置信息因为他在这一点继承了 Transformer 的 self attention 操作,但是 BERT 的位置信息既可以和 Transformer 一样采用位置编码(Positional encoding)的方法用固定的 sin, cos 函数来实现,也可以采用将位置信息输入 embedding 层让他自己学习出更好的位置表示
代码复现
"""
@Time : 2022/10/22
@Author : Peinuan qin
"""
import torch
import torch.nn as nn
from transformers import BertTokenizer,BertModel, BertConfig
from Dataset import BERTDataset
class PositionalEmbedding(nn.Embedding):
def __init__(self, d_model, max_len=512):
super(PositionalEmbedding, self).__init__(max_len, d_model, padding_idx=0)
class SegmentEmbedding(nn.Embedding):
def __init__(self, d_model, segement_num=2):
super(SegmentEmbedding, self).__init__(segement_num, d_model, padding_idx=0)
class TokenEmbedding(nn.Embedding):
def __init__(self, d_model, vocab_size,):
super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=0)
class BERTEmbedding(nn.Module):
def __init__(self, vocab_size, d_model, drop_rate=0.1):
super(BERTEmbedding, self).__init__()
self.token_embedding = TokenEmbedding(d_model, vocab_size)
self.position_embedding = PositionalEmbedding(d_model)
self.segment_embedding = SegmentEmbedding(d_model)
self.dropout = nn.Dropout(drop_rate)
def forward(self, sequence, segment_labels, position_ids):
x = self.token_embedding(sequence) + self.segment_embedding(segment_labels) + self.position_embedding(position_ids)
return self.dropout(x)
if __name__ == '__main__':
model_name = '../bert_pretrain_base/'
d_model = 768
config = BertConfig.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
bert_embedding = BERTEmbedding(vocab_size=config.vocab_size, d_model=d_model)
dataset = BERTDataset(corpus_path="./corpus_chinese.txt"
, tokenizer=tokenizer
, seq_len=20)
sample = dataset[0]
input_ids = sample['input_ids']
segment_labels = sample['segment_labels']
position_ids = torch.tensor([i for i in range(len(input_ids))])
print(sample)
x = bert_embedding(input_ids, segment_labels, position_ids)
print(x)
- 看不懂 main 函数中前半截操作可以参考上篇文章 ,我采用了 huggingface 提供的预训练模型来提供 tokenizer 和 词表