RoBERTa相比BERT的改进

继BERT、XLNet之后,Facebook提出的RoBERTa(a Robustly Optimized BERT Pretraining Approach)。本篇文章主要总结下RoBERTa相比于BERT的改进。
RoBERTa在模型结构层面没有改变,改变的只是预训练的方法,具体是以下三点。
1.动态mask
RoBERTa把预训练的数据复制10份,每一份都随机选择15%的Tokens进行mask,也就是说,同样的一句话有10种不同的mask方式。然后每份数据都训练N/10个epoch。这就相当于在这N个epoch的训练中,每个序列的被mask的tokens是会变化的。

BERT随机选择15%的tokens替换,为了消除上下游任务的不匹配问题,对这15%的tokens进行:(1)80%的时间替换为[MASK] (2)10%的时间不变 (3)10%的时间替换为其他词

在这里插入图片描述

2.去掉NSP任务
RoBERTa去掉了NSP任务,使用FULL-SENTENCES训练方式。每次输入连续的多个句子,直到最大长度512。
BERT使用了NSP任务,对于输入的两个句子判断是否为连续的。训练数据的构成为,50%的样本是同一篇文章的上下句,50%的样本是不同文章的两句话。
3.更大的mini-batch,更多的数据
(1)RoBERTa使用的batch size为8k,BERT使用的batch size是256。
(2)RoBERTa使用的数据量为160G,BERT使用的数据量为13G
在这里插入图片描述

"""
bert mask的事例代码
"""
import random as rng
import collections
tokens = ["I","like","you","I","like","you","I","like","you","I","like","you","I","like","you","I","like","you","I","like","you"]
max_predictions_per_seq = 100
masked_lm_prob = 0.15
vocab_words = {0:"X",1:"1",2:"2",3:"3",4:"4",5:"5",6:"6",7:"7",8:"8",9:"9",10:"10"}
cand_indexes = []

for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      continue
    cand_indexes.append(i)

rng.shuffle(cand_indexes)
masked_lm = collections.namedtuple("masked_lm", ["index", "label"])
# mask的长度
num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
print("num_to_predict:", num_to_predict)
masked_lms = []
covered_indexes = set()
output_tokens = list(tokens)

# 取15%的长度
for index in cand_indexes:
    if len(masked_lms) >= num_to_predict:
        break
    if index in covered_indexes:
        continue
    covered_indexes.add(index)
    masked_token = None
    
    # 在15%中选80%替换为mask
    if rng.random() < 0.8:
        masked_token = "[MASK]"
    else:
        if rng.random() > 0.5:
        # 在15%中,20%*50%=10%的概率为不变
            masked_token = tokens[index]
        else:
        # 20%*50%=10%的概率随机替换
            masked_token = vocab_words[rng.randint(0,len(vocab_words)-1)]
    output_tokens[index] = masked_token
    masked_lms.append(masked_lm(index=index, label=tokens[index]))
print("output_tokens:\n",output_tokens)        
print("masked_lms:\n", masked_lms)
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值