GPT BERT等模型如何添加新的token

废话不多说,直接上code

def add_special_tokens(model, tokenizer):
    orig_num_tokens = len(tokenizer.vocab)
    num_add_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
    if num_add_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_add_tokens)

这个方法是借助huggingface的transformer库进行实现,其中model可以为huggingface支持的任何一个模型,如bert,gpt,robert等,tokenizer可以为BertTokenizer GPT2Tokenizer 等。

下面看看是如何进行添加的。

第一步:添加到词表,指定对应的index

可以用 

new_tokens = ['token1', 'token2'] 
tokenizer.add_tokens(new_tokens)

或者

special_tokens = {'additional_special_tokens':['token1', 'token2'] }
tokenizer.add_special_tokens(special_tokens)

第二步:对模型token embedding 矩阵进行修改,大小由(voc_size, emd_size)改为添加新词后的大小(voc_size+new_token_num, emd_size),具体实现见以下代码

def _get_resized_embeddings(self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None)

    if new_num_tokens is None:
        return old_embeddings

    old_num_tokens, old_embedding_dim = old_embeddings.weight.size() #获取原来token数量与 embedding 维度

    if old_num_tokens == new_num_tokens:
        return old_embeddings

    # Build new embeddings
    new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)  #根据当前token个数初始化新的token embedding 矩阵
    new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

    # initialize all new embeddings (in particular added tokens)
    self._init_weights(new_embeddings) # 对新的token embedding 矩阵随机初始化权重

    # Copy token embeddings from the previous weights

    # numbers of tokens to copy
    n = min(old_num_tokens, new_num_tokens)
    new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] # 将原来token embedding参数copy到新的token embedding矩阵对应位置

    return new_embeddings

这样后面就可以对新的token进行编解码了

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值