【无标题】

import tensorflow as tf


def beam_search2(batch_size, beam_width, vocab_size, bos_id, eos_id, inst, max_length):
  """Build a TF1 graph that beam-searches over the decoder exposed by `inst`.

  Two placeholders are created inside this function and are not returned;
  feed them by name: 'x' (int32, [batch_size, T] source ids) and
  'x_sequence_length' (int32, [batch_size] source lengths).

  Args:
    batch_size: Python int, number of source sentences (B).
    beam_width: Python int, hypotheses kept per sentence (b).
    vocab_size: Python int, size of the output vocabulary (V).
    bos_id: id every hypothesis starts with.
    eos_id: id that marks a hypothesis as finished.
    inst: model object providing `encode(x_input=, x_mask=)` and
      `decode(y_input, y_mask, memory, memory_mask)`.
      NOTE(review): this function assumes the first value returned by
      `decode` holds per-position probabilities shaped [B*b, T, V] (that is
      the order the `Transformer` stub in this file returns) - confirm for
      real models.
    max_length: upper bound on the number of decoding steps.

  Returns:
    The result of `tf.while_loop` over the loop variables
    `(time_, finished, input_ta, pre_score)`:
      time_: number of steps executed.
      finished: bool [B*b, 1], True where a hypothesis has emitted eos_id.
      input_ta: TensorArray whose entry `time_` holds the final hypotheses,
        int32 [B*b, time_ + 1] including the leading bos_id. Entries are
        cleared once read.
      pre_score: float32 [B*b, 1] cumulative (un-normalised) log-probability.
  """
  flat_size = batch_size * beam_width
  neg_inf = -1.0e9  # finite stand-in for -inf so sums never become NaN

  x_placeholder = tf.placeholder(dtype=tf.int32, shape=[batch_size, None], name='x')
  x_sequence_length_placeholder = tf.placeholder(dtype=tf.int32,
                                                 shape=[batch_size, ],
                                                 name='x_sequence_length')
  x_mask = tf.sequence_mask(lengths=x_sequence_length_placeholder, dtype=tf.int32)

  # NOTE(review): memory / memory_mask have B rows while the decoder is fed
  # B*b rows. The stub model in this file ignores them; a real model probably
  # needs them repeated beam_width times - confirm before use.
  memory = inst.encode(x_input=x_placeholder, x_mask=x_mask)
  memory_mask = x_mask

  # Nothing is finished before the first step. The flag must be bool because
  # tf.reduce_all only accepts bool input.
  finished = tf.zeros(shape=[flat_size, 1], dtype=tf.bool)

  # Every beam starts from the same bos-only prefix, so only the first beam of
  # each sentence is live at step 0; otherwise top_k would select the same
  # continuation beam_width times.
  pre_score = tf.constant(value=[[0.0] + [neg_inf] * (beam_width - 1)] * batch_size,
                          dtype=tf.float32)
  pre_score = tf.reshape(tensor=pre_score, shape=[flat_size, 1])

  # `size` is mandatory when no handle is supplied. Elements grow by one
  # column per step, so shape inference is disabled and no element_shape is
  # pinned; the static row count is restored after each read instead.
  input_ta = tf.TensorArray(dtype=tf.int32,
                            size=1,
                            dynamic_size=True,
                            clear_after_read=True,
                            infer_shape=False)
  input_ta = input_ta.write(0, tf.constant(value=bos_id, shape=[flat_size, 1], dtype=tf.int32))

  # B x 1: index of each sentence, used to turn per-sentence indices into
  # indices over the flattened B*b rows.
  batch_offsets = tf.expand_dims(tf.range(batch_size), axis=1)

  # B*b x V: log-probabilities that allow only eos_id, at zero cost. Applied
  # to finished hypotheses so they are carried forward unchanged.
  eos_only = tf.one_hot(indices=eos_id, depth=vocab_size,
                        on_value=0.0, off_value=neg_inf, dtype=tf.float32)
  eos_only = tf.tile(tf.expand_dims(eos_only, axis=0), [flat_size, 1])

  def cond_fn(time_, finished, input_ta, pre_score):
    return tf.logical_not(tf.reduce_all(finished))

  def body_fn(time_, finished, input_ta, pre_score):
    def score_fn(length, total_log_prob):
      # Length-normalised score used only for ranking candidates.
      return total_log_prob / tf.pow(length, 0.7)

    # B*b x (time_ + 1): bos followed by the tokens generated so far
    y_input = input_ta.read(time_)
    y_input.set_shape([flat_size, None])
    cur_len = time_ + 1
    y_mask = tf.sequence_mask(lengths=tf.fill([flat_size], cur_len),
                              maxlen=cur_len,
                              dtype=tf.int32)

    # B*b x T x V
    scores, _ = inst.decode(y_input, y_mask, memory, memory_mask)

    # B*b x V: distribution for the next token, taken at the last position
    step_log_prob = tf.log(x=scores[:, -1, :])
    step_log_prob = tf.where(tf.tile(finished, [1, vocab_size]), eos_only, step_log_prob)
    total_log_prob = pre_score + step_log_prob

    # B*b x 1: candidate length in generated tokens. Counting the non-eos
    # tokens after bos and adding one gives time_ + 1 for live hypotheses and
    # the frozen final length (including eos) for finished ones.
    not_eos = tf.cast(tf.not_equal(y_input[:, 1:], eos_id), tf.float32)
    hyp_length = tf.expand_dims(tf.reduce_sum(not_eos, axis=1), axis=1) + 1.0

    ranked = score_fn(hyp_length, total_log_prob)
    ranked = tf.reshape(tensor=ranked, shape=[batch_size, beam_width * vocab_size])
    top_k_output = tf.nn.top_k(input=ranked, k=beam_width)

    # B x b: each index encodes parent_beam * vocab_size + token_id
    choice = top_k_output.indices
    parent_rows = tf.reshape(tensor=tf.floordiv(choice, vocab_size) + batch_offsets * beam_width,
                             shape=[flat_size])
    next_ids = tf.reshape(tensor=tf.mod(x=choice, y=vocab_size), shape=[flat_size, 1])

    # Carry the raw cumulative log-probability of the chosen candidates, so
    # the length normalisation is not compounded from step to step.
    picked = tf.reshape(tensor=choice + batch_offsets * (beam_width * vocab_size),
                        shape=[flat_size])
    pre_score = tf.gather(tf.reshape(tensor=total_log_prob, shape=[-1]), picked)
    pre_score = tf.reshape(tensor=pre_score, shape=[flat_size, 1])

    # A hypothesis stays finished once its parent was finished.
    finished = tf.logical_or(tf.gather(finished, parent_rows),
                             tf.equal(x=next_ids, y=eos_id))

    # Extend the selected parent prefixes; bos is already the first column.
    next_seq = tf.concat(values=[tf.gather(y_input, parent_rows), next_ids], axis=-1)
    input_ta = input_ta.write(time_ + 1, next_seq)

    return time_ + 1, finished, input_ta, pre_score

  time_ = tf.constant(value=0, dtype=tf.int32)
  res = tf.while_loop(cond=cond_fn,
                      body=body_fn,
                      loop_vars=[time_, finished, input_ta, pre_score],
                      maximum_iterations=max_length)
  return res


class Transformer:
  """Stand-in model that returns random tensors from `encode` and `decode`.

  Both methods ignore the content of their inputs and only use them to size
  the output, so the class is useful for exercising search code without a
  trained network.
  """

  def __init__(self, batch_size, vocab_size, bos_id, eos_id, max_length):
    """Store the configuration values in `self.params`."""
    # The dict has to exist before it is filled; assigning keys into a missing
    # attribute raised AttributeError on construction.
    self.params = {
      "batch_size": batch_size,
      "vocab_size": vocab_size,
      "eos_id": eos_id,
      "bos_id": bos_id,
      "max_length": max_length,
    }

  def get_config(self):
    """Return the configuration dict (the live object, not a copy)."""
    return self.params

  def encode(self, x_input, x_mask):
    """Return random `(score, logit)` tensors for a source batch.

    Args:
      x_input: tensor whose last dimension is the source length T.
      x_mask: unused.

    Returns:
      Tuple `(x_score, x_logit)`, each float32 of shape
      [batch_size, T, vocab_size]; scores are uniform in [0, 1) and the
      logits are the scores multiplied by 5.
    """
    # input: B x T
    # output: B x T x V
    shape = [self.params["batch_size"], tf.shape(input=x_input)[-1], self.params["vocab_size"]]
    x_score = tf.random_uniform(shape=shape, minval=0.0, maxval=1.0, dtype=tf.float32)
    x_logit = x_score * 5.0

    return x_score, x_logit

  def decode(self, y_input, y_mask, memory, memory_mask):
    """Return random `(score, logit)` tensors for a target batch.

    Args:
      y_input: 2-D tensor [N, T] of target ids; only its shape is used.
      y_mask: unused.
      memory: unused.
      memory_mask: unused.

    Returns:
      Tuple `(x_score, x_logit)`, each float32 of shape [N, T, vocab_size];
      scores are uniform in [0, 1) and the logits are the scores multiplied
      by 5.
    """
    # input: N x T
    # output: N x T x V
    # The leading dimension follows the input rather than the configured batch
    # size, because beam search feeds batch_size * beam_width rows.
    input_shape = tf.shape(input=y_input)
    shape = [input_shape[0], input_shape[-1], self.params["vocab_size"]]
    x_score = tf.random_uniform(shape=shape, minval=0.0, maxval=1.0, dtype=tf.float32)
    x_logit = x_score * 5.0

    return x_score, x_logit


# Script entry point: intentionally empty, nothing runs when executed directly.
if __name__ == "__main__":
  pass
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值