NEZHA(中文版GPT2)实现Beam-search Tensorflow1.15 (一)

平时一直使用Transformers的包来调用GPT2或者生成模型,其generate函数封装了top-k和top-p以及beam-search函数。最近在使用华为的NEZHA(中文GPT2),这个模型的生成模型是用Tensorflow1.15写的。它的sample方式是top-k和top-p的sampling,最近需要使用beam-search,因此根据NEZHA的代码来实现beam-search。

在输入到模型之前,需要对输入copy为beam size的大小,相当于是在生成的时候维护beam size大小个句子:

predict = init_input_ids
predict = tf.tile(predict, [beam_width, 1])

首先,将copy后的predict输入到模型中,step为通过NEZHA得到了预测出的输出logits, 保存中间结果(transformer的k,v等值,用于加速decoding的速度)的past,以及当前处于生成的第几步predict_token_idx(初始值为0,每运行一次加一)

predict_logits, past, predict_token_idx = step(predict, past=past,     
                                predict_token_idx=predict_token_idx)

将输出的概率通过log softmax,之所以要加上log,是因为beam-search在选择的时候是通过语言模型的概率累乘,而加上log后把累乘转换成累加,方便计算。

log_probs = tf.nn.log_softmax(predict_logits, axis=-1)

之后需要分情况,如果是第一步生成,由于原先的predict中的beam width个句子都是一样,因此predict_logits的对于第一步来说,只需要选predict_logits(beam width, vocab size)的其中一个predict_logits[0],从里面挑选出top-k(k为beam width)个选项,将这topk个的log probs作为初始状态的score(sel_sum_logprobs)

sel_sum_logprobs, indices = tf.nn.top_k(tf.expand_dims(log_probs[0], 0), k=beam_width)

第一步得到的beam width个候选,将其赋值给predicted_tokens。

predicted_tokens = tf.expand_dims(tf.identity(predict), 2)

第一步初始化就结束了,之后每一步就根据初始化的结果进行操作,对于接下来每一步,需要将每一步生成的预测的输出概率log_probs和之前的累加的概率sel_sum_logprobs相加:

predict_logits, past, predict_token_idx = step(predict, past=past, 
                                     predict_token_idx=predict_token_idx)
log_probs = tf.nn.log_softmax(predict_logits, axis=-1)
sum_logprobs = (tf.expand_dims(sel_sum_logprobs, axis=2) + (log_probs))

从累加后的概率中选出,top-k个候选,在top-k之前需要将sum_logprobs的大小进行转换((batch, beam width, vocab_size) -->(batch, beam width * vocab_size) 。其原因是因为要从所有beam的候选词中进行选择

sel_sum_logprobs, indices = tf.nn.top_k(
            tf.reshape(sum_logprobs, [batch_size, vocab_size* beam_width]),
            k=beam_width)

 其得到的indices是beam width所有的indices,需要知道它在哪一个beam,以及在beam内部中的indice

predict = indices % vocab_size
beam_ids = indices // vocab_size

对于每一个预测的词predict,要判断它是属于哪一个beam id,并把它和beam id对应的已经生成的predicted_tokens拼接,同时past也需要根据beam id更新的past,其原因是得到了新的top-k个predict token,其对应的已经生成的predicted tokens不一定相同,因此需要将past进行更新。

predicted_tokens = tf.concat([batch_gather(predicted_tokens, beam_ids), tf.expand_dims(predict, 2)], axis=2)

past = tf.gather(past, tf.squeeze(beam_ids), axis=2)

这样经过若干轮后,得到的predicted_tokens就是得到色sequence。

拿“一时间风云变幻”作为输入,一般来说beam-search的效果会比greedy的方法要好一些。但是感觉重复现象很严重,这里也实现了repetition penality=1.2,其重复现象会缓解很多。repetition penality的实现以及对于在tensorflow1.15中优化代码放在后篇之后再写

方法句子
greedy一时间风云变幻,各种各样的新闻事件层出不穷,各种各样的人物也层出不穷,各种各样的人物也层出不穷,各种各样的人物也层
beam-search一时间风云变幻,各种各样的传闻纷至沓来。一时间,各种各样的传闻纷至沓来。一时间,各种各样的传闻纷至沓来。一时间,各种

beam-search

(r_p=1.2)

一时间风云变幻,各种各样的传闻纷至沓来。在这种情况下,我们不得不面对一个现实,那就是中国的网络游戏市场已经进入了一个

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值