Environment: transformers 2.11.0
Problem description:
In many NLP tasks we feed a pretrained model such as BERT words that never appear in its tokenizer vocabulary. Registering these words explicitly prevents the tokenizer from splitting them into other subword pieces.
E.g. The system as described above has its greatest application in an arrayed <e1> configuration </e1> of antenna <e2> elements </e2>.
In the example above, <e1></e1> marks entity 1 and <e2></e2> marks entity 2. We do not want the tokenizer to split these markers into other tokens, so we need to add markers such as <e1> as "special tokens", i.e. words that never appear in the vocab.txt vocabulary.
This is done as follows:
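To see why this matters, here is a toy greedy longest-match splitter in the style of WordPiece (an illustrative sketch with a made-up mini vocabulary, NOT the real transformers tokenizer): an unregistered marker like <e1> gets shredded into subword pieces, while a registered special token passes through whole.

```python
# Toy greedy longest-match subword splitter -- illustrative only,
# NOT the real BERT WordPiece implementation.
VOCAB = {"<", ">", "/", "e", "1", "2", "##e", "##1", "##2", "##>",
         "config", "##uration"}
SPECIAL_TOKENS = set()  # markers registered as unsplittable

def wordpiece(word):
    """Split `word` greedily into the longest matching vocab pieces."""
    if word in SPECIAL_TOKENS:
        return [word]          # registered special tokens pass through whole
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        while end > start:
            piece = word[start:end] if start == 0 else "##" + word[start:end]
            if piece in VOCAB:
                pieces.append(piece)
                break
            end -= 1
        if end == start:
            return ["[UNK]"]   # no piece matched at all
        start = end
    return pieces

print(wordpiece("<e1>"))       # fragmented: ['<', '##e', '##1', '##>']
SPECIAL_TOKENS.add("<e1>")
print(wordpiece("<e1>"))       # kept whole: ['<e1>']
```

The real tokenizer behaves analogously: until <e1> is added as a special token, it is just an unknown string and gets broken apart.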
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('Bert')  # the argument is the local path to the BERT checkpoint
model = BertModel.from_pretrained('Bert')  # the argument is the local path to the BERT checkpoint
# add special tokens, e.g. <e1>
ADDITIONAL_SPECIAL_TOKENS = ["<e1>", "</e1>", "<e2>", "</e2>"]
tokenizer.add_special_tokens({"additional_special_tokens": ADDITIONAL_SPECIAL_TOKENS})
After adding the special tokens, running the model raises the following error:
RuntimeError: index out of range: Tried to access index 30522 out of table with 30521 rows. at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:418
The cause of this error: the original vocab.txt indexes its words from 0 up to 30521, but after adding the 4 special tokens the indices run up to 30525, while the model's word-embedding table still only covers the indices up to 30521 that were fixed when the model was defined. The model's vocabulary size therefore needs to be updated, as follows:
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('Bert')  # the argument is the local path to the BERT checkpoint
model = BertModel.from_pretrained('Bert')  # the argument is the local path to the BERT checkpoint
# add special tokens, e.g. <e1>
ADDITIONAL_SPECIAL_TOKENS = ["<e1>", "</e1>", "<e2>", "</e2>"]
tokenizer.add_special_tokens({"additional_special_tokens": ADDITIONAL_SPECIAL_TOKENS})
# update the size of the model's token embedding table
model.resize_token_embeddings(len(tokenizer))
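The index arithmetic behind the fix can be checked with a plain-Python stand-in for the embedding table (a sketch only, assuming BERT-base's 30522-entry vocabulary; this is not the transformers API):

```python
# Simulate the embedding table as a plain list of row vectors.
OLD_VOCAB_SIZE = 30522            # BERT-base vocab: indices 0..30521
NUM_ADDED = 4                     # <e1>, </e1>, <e2>, </e2>

table = [[0.0] * 4 for _ in range(OLD_VOCAB_SIZE)]  # tiny stand-in rows

new_token_id = OLD_VOCAB_SIZE     # the first added token gets id 30522
try:
    _ = table[new_token_id]       # what the model does before resizing
except IndexError:
    print("index out of range")   # same failure mode as the RuntimeError above

# resize_token_embeddings(len(tokenizer)) grows the table to fit:
table.extend([[0.0] * 4 for _ in range(NUM_ADDED)])
_ = table[new_token_id]           # now the lookup succeeds
print(len(table))                 # 30526 rows, covering indices 0..30525
```

This is why `resize_token_embeddings` must receive `len(tokenizer)`, i.e. the full new vocabulary size, not just the number of added tokens.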
Below is the example from the official docs, which is essentially the same as what I wrote above:
# Let's see how to add a new classification token to GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
special_tokens_dict = {'cls_token': '<CLS>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
assert tokenizer.cls_token == '<CLS>'
Documentation link: the transformers tokenizer documentation