import tensorflow as tf
def beam_search2(batch_size, beam_width, vocab_size, bos_id, eos_id, inst, max_length):
    """Build a TF1-style beam-search decoding graph.

    Args:
        batch_size: number of source sentences per batch (B).
        beam_width: number of hypotheses kept per sentence (b).
        vocab_size: size of the target vocabulary (V).
        bos_id: id of the begin-of-sequence token.
        eos_id: id of the end-of-sequence token.
        inst: model exposing `encode(x_input, x_mask)` and
            `decode(y_input, y_mask, memory, memory_mask)`; `decode` is used
            here as returning `(logits, scores)` with scores indexable as
            [B*b, T, V] — TODO confirm against the actual model.
        max_length: maximum number of decoding steps.

    Returns:
        The `tf.while_loop` result list `[time_, finished, input_ta, pre_score]`.
    """
    x_placeholder = tf.placeholder(dtype=tf.int32, shape=[batch_size, None], name='x')
    x_sequence_length_placeholder = tf.placeholder(dtype=tf.int32, shape=[batch_size, ])
    x_mask = tf.sequence_mask(lengths=x_sequence_length_placeholder, dtype=tf.int32)
    memory = inst.encode(x_input=x_placeholder, x_mask=x_mask)
    memory_mask = x_mask
    # BUG FIX: `finished` feeds `tf.reduce_all`/`tf.logical_or`, so it must be
    # bool, start all-False (the original int value 1 made the loop exit
    # immediately), and carry one flag per beam ([B*b, 1]) to match the
    # `next_ids` produced in the body.
    finished = tf.fill(dims=[batch_size * beam_width, 1], value=False)
    # BUG FIX: elements are [B*b, T] (not [B, T]); seed index 0 with the BOS
    # column so the first `read(0)` inside the loop does not fail.
    input_ta = tf.TensorArray(dtype=tf.int32,
                              size=1,
                              dynamic_size=True,
                              clear_after_read=True,
                              element_shape=[batch_size * beam_width, None])
    input_ta = input_ta.write(0, tf.fill(dims=[batch_size * beam_width, 1],
                                         value=bos_id))

    def cond_fn(time_, finished, input_ta, pre_score):
        # Loop until every beam has emitted EOS (the `maximum_iterations`
        # cap below bounds the loop as well).
        return tf.logical_not(tf.reduce_all(finished))

    def body_fn(time_, finished, input_ta, pre_score):
        def score_fn(time_, pre_score, cur_score):
            # Length-normalized cumulative log-probability.
            # BUG FIX: normalize by (time_ + 1)**0.7 — `time_` is 0 on the
            # first step, so the original `1/time_**0.7` divided by zero.
            penalty = tf.pow(tf.cast(time_ + 1, dtype=tf.float32), 0.7)
            return (pre_score + tf.log(x=cur_score)) / penalty

        # B*b x (time_+1): BOS plus the tokens generated so far.
        y_input = input_ta.read(time_)
        sequence_length = tf.fill(dims=[batch_size * beam_width], value=time_ + 1)
        y_mask = tf.sequence_mask(lengths=sequence_length, dtype=tf.int32)
        # B*b x T x V
        logits, scores = inst.decode(y_input, y_mask, memory, memory_mask)
        # BUG FIX: the original swapped `pre_score` and the current step's
        # scores at this call site. `pre_score` is [B*b, 1] and broadcasts
        # over the vocabulary axis.
        step_score = score_fn(time_, pre_score, scores[:, time_, :])
        # B x (b*V): fold the beams into the candidate axis for top-k.
        scores_next = tf.reshape(tensor=step_score,
                                 shape=[batch_size, beam_width * vocab_size])
        top_k_output = tf.nn.top_k(input=scores_next, k=beam_width)
        # B x b token ids within the vocabulary.
        next_ids = tf.mod(x=top_k_output.indices, y=vocab_size)
        next_ids = tf.reshape(tensor=next_ids, shape=[batch_size * beam_width, 1])
        # Carry only the surviving candidates' scores; keeps the loop-var
        # shape fixed at [B*b, 1] across iterations.
        pre_score = tf.reshape(tensor=top_k_output.values,
                               shape=[batch_size * beam_width, 1])
        # BUG FIX: accumulate finished flags instead of overwriting them,
        # otherwise a beam "un-finishes" whenever it later emits a non-EOS
        # token.
        finished = tf.logical_or(finished, tf.equal(x=next_ids, y=eos_id))
        # BUG FIX: index 0 already holds BOS, so do not prepend it again.
        # NOTE(review): like the original, this does not gather the parent
        # beams (top_k indices // vocab_size) before appending — confirm
        # whether hypothesis reordering is required for your use case.
        input_ta = input_ta.write(time_ + 1,
                                  tf.concat(values=[y_input, next_ids], axis=-1))
        return time_ + 1, finished, input_ta, pre_score

    time_ = tf.constant(0, dtype=tf.int32)
    # Running (normalized) score, one per beam.
    pre_score = tf.zeros(shape=[batch_size * beam_width, 1], dtype=tf.float32)
    # BUG FIX: `pre_score` was missing from `loop_vars`, so the 4-argument
    # cond/body functions raised a TypeError when the loop was built.
    res = tf.while_loop(cond=cond_fn, body=body_fn,
                        loop_vars=[time_, finished, input_ta, pre_score],
                        maximum_iterations=max_length)
    return res
class Transformer:
    """Stub sequence-to-sequence model used to exercise the beam-search graph.

    `encode`/`decode` do no real modelling: they return uniformly random
    scores (and scores * 5 as "logits") of shape B x T x V.
    """

    def __init__(self, batch_size, vocab_size, bos_id, eos_id, max_length):
        # BUG FIX: `self.params` was assigned into without ever being
        # created, so construction raised AttributeError.
        self.params = {
            "batch_size": batch_size,
            "vocab_size": vocab_size,
            "eos_id": eos_id,
            "bos_id": bos_id,
            "max_length": max_length,
        }

    def get_config(self):
        """Return the hyper-parameter dictionary."""
        return self.params

    def encode(self, x_input, x_mask):
        """Fake encoder.

        Args:
            x_input: int tensor, B x T; only its last dimension is read.
            x_mask: unused here; kept for interface compatibility.
        Returns:
            (x_score, x_logit): random B x T x V scores in [0, 1) and
            the same tensor scaled by 5.0.
        """
        shape = [self.params["batch_size"], tf.shape(input=x_input)[-1], self.params["vocab_size"]]
        x_score = tf.random_uniform(shape=shape, minval=0.0, maxval=1.0, dtype=tf.float32)
        x_logit = x_score * 5.0
        return x_score, x_logit

    def decode(self, y_input, y_mask, memory, memory_mask):
        """Fake decoder.

        Args:
            y_input: int tensor, B x T; only its last dimension is read.
            y_mask, memory, memory_mask: unused here; kept for interface
                compatibility with the beam-search driver.
        Returns:
            (x_score, x_logit): random B x T x V scores in [0, 1) and
            the same tensor scaled by 5.0.
        """
        shape = [self.params["batch_size"], tf.shape(input=y_input)[-1], self.params["vocab_size"]]
        x_score = tf.random_uniform(shape=shape, minval=0.0, maxval=1.0, dtype=tf.float32)
        x_logit = x_score * 5.0
        return x_score, x_logit
# Script entry point: module currently defines graph builders only.
if __name__ == "__main__":
    pass
# (stray web-page footer text, kept out of the code path:
#  "Untitled" — "latest recommended article published 2024-07-12 23:11:20")