import numpy as np
import tensorflow as tf
import os
import json
# Sentinel characters that mark the beginning and end of every poem
# in the preprocessed corpus.
start_token = 'B'
end_token = 'E'
# Directory where training checkpoints are written and restored from.
model_dir = './model'
def process_poems(poems_path='poems_index.npy',
                  word2int_path='word2int.json',
                  int2word_path='int2word.json'):
    """Load the preprocessed poem corpus and its vocabulary mappings.

    Args:
        poems_path: .npy file holding an object array of per-poem id sequences.
        word2int_path: JSON file mapping characters to integer ids.
        int2word_path: JSON file mapping integer ids (as strings) to characters.

    Returns:
        Tuple (poems_vector, word_int_map, int_word_map).
    """
    # allow_pickle is required: the array stores variable-length Python lists.
    poems_vector = np.load(poems_path, allow_pickle=True)
    # Explicit UTF-8: the vocabulary is Chinese text and must not depend on
    # the platform's default encoding.
    with open(word2int_path, 'r', encoding='utf-8') as f:
        word_int_map = json.load(f)
    with open(int2word_path, 'r', encoding='utf-8') as f:
        int_word_map = json.load(f)
    return poems_vector, word_int_map, int_word_map
def rnn_model(model,
              input_data,
              output_data,
              vocab_size,
              rnn_size=128,
              batch_size=64,
              learning_rate=0.01):
    """Build a character-level RNN graph for poem generation (TF1 style).

    Args:
        model: one of 'rnn', 'lstm', 'gru' — selects the recurrent cell.
        input_data: int32 tensor [batch, seq_len] of character ids.
        output_data: int32 tensor of targets to build the training graph,
            or None to build the sampling/inference graph.
        vocab_size: number of distinct characters; one extra slot is
            reserved (vocab_size + 1) for the padding id.
        rnn_size: hidden-state width, also used as the embedding dimension.
        batch_size: training batch size (inference always uses batch 1).
        learning_rate: Adam learning rate (training graph only).

    Returns:
        dict of graph endpoints: always 'initial_state' and 'last_state';
        plus 'train_op'/'total_loss' when training, or 'prediction'
        (softmax over ids) when generating.
    """
    end_points = {}
    # vocab_size + 1 reserves an embedding row for the padding token id.
    embedding = tf.get_variable('embedding',
                                initializer=tf.random_uniform(
                                    [vocab_size + 1, rnn_size], -1.0, 1.0))
    inputs = tf.nn.embedding_lookup(embedding, input_data)
    if model == 'rnn':
        cell = tf.contrib.rnn.BasicRNNCell(rnn_size)
    elif model == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(rnn_size, state_is_tuple=True)
    elif model == 'gru':
        cell = tf.contrib.rnn.GRUCell(rnn_size)
    else:
        # Fail fast instead of hitting a confusing NameError on `cell` below.
        raise ValueError("model must be 'rnn', 'lstm' or 'gru', got %r" % model)
    if output_data is not None:
        # Training: initialize the recurrent state to all zeros per batch.
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        # Inference generates one character at a time with batch size 1.
        initial_state = cell.zero_state(1, tf.float32)
    outputs, last_state = tf.nn.dynamic_rnn(cell,
                                            inputs,
                                            initial_state=initial_state)
    # Flatten to [batch * time, rnn_size] so one dense layer scores all steps.
    outputs = tf.reshape(outputs, [-1, rnn_size])
    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(outputs, weights), bias=bias)
    if output_data is not None:
        labels = tf.one_hot(tf.reshape(output_data, [-1]),
                            depth=vocab_size + 1)
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits)
        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
        end_points['initial_state'] = initial_state
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['last_state'] = last_state
    else:
        prediction = tf.nn.softmax(logits)
        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction
    return end_points
def run_training():
    """Train the LSTM poem model and checkpoint it under `model_dir`."""
    batch_size = 64
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    poems_vector, word_to_int, vocabularies = process_poems()
    ds = tf.data.Dataset.from_generator(lambda: [ins for ins in poems_vector],
                                        tf.int32,
                                        tf.TensorShape([None]))
    ds = ds.shuffle(buffer_size=poems_vector.shape[0])
    ds = ds.repeat()
    # Pad each batch to its longest poem; assumes ' ' is present in the
    # vocabulary as the padding character — TODO confirm in preprocessing.
    ds = ds.padded_batch(batch_size,
                         padded_shapes=tf.TensorShape([None]),
                         padding_values=word_to_int[' '])
    # Inputs are characters [0:T-1]; targets are the same sequence shifted
    # left by one (next-character prediction).
    ds = ds.map(lambda x: (x[:, :-1], x[:, 1:]))
    iterator = ds.make_initializable_iterator()
    Xs, Ys = iterator.get_next()
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    output_targets = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies))
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(iterator.initializer)
        for epoch in range(50):
            # BUGFIX: one epoch covers the corpus, so the batch count must
            # come from the number of poems, not the vocabulary size
            # (the original used len(word_to_int) // 64).
            for batch in range(len(poems_vector) // batch_size):
                batch_xs, batch_ys = sess.run([Xs, Ys])
                loss, _ = sess.run(
                    [end_points['total_loss'], end_points['train_op']],
                    feed_dict={
                        input_data: batch_xs,
                        output_targets: batch_ys
                    })
                print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
            print('Epoch: %d, training loss: %.6f' % (epoch, loss))
            if epoch % 10 == 0:
                saver.save(sess,
                           os.path.join(model_dir, "poems"),
                           global_step=epoch)
        # Also checkpoint the final state; epoch 49 is otherwise never saved.
        saver.save(sess, os.path.join(model_dir, "poems"), global_step=epoch)
def prediction_to_word(predict, vocabs):
    """Sample the next character from the network's softmax output.

    Args:
        predict: array whose first row is a (possibly unnormalized)
            distribution over vocab_size + 1 ids.
        vocabs: dict mapping id strings ('0', '1', ...) to characters.

    Returns:
        The sampled character, or ' ' when an id outside the vocabulary
        (the extra padding slot) is drawn.
    """
    # Renormalize defensively, into a fresh array so the caller's
    # prediction buffer is not mutated.
    probs = predict[0] / np.sum(predict[0])
    sample = np.random.choice(np.arange(len(probs)), p=probs)
    # The distribution covers vocab_size + 1 ids but `vocabs` maps only
    # vocab_size characters: use >= so sample == len(vocabs) maps to a
    # space instead of raising KeyError (the original `>` was off by one).
    if sample >= len(vocabs):
        return ' '
    return vocabs[str(sample)]
def gen_poem(begin_word):
    """Generate one poem (capped at 25 characters) seeded by `begin_word`.

    Restores the latest checkpoint from `model_dir`, primes the network
    with the start token, then samples characters one step at a time,
    feeding each step's state back in, until the end token or the cap.
    """
    batch_size = 1
    poems_vector, word_int_map, vocabularies = process_poems()
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=None,
                           vocab_size=len(vocabularies),
                           learning_rate=0.0002)
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        # Restore the most recent training checkpoint.
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
        fetches = [end_points['prediction'], end_points['last_state']]
        # Prime the recurrent state with the start token.
        seed = np.array([list(map(word_int_map.get, start_token))])
        predict, last_state = sess.run(fetches, feed_dict={input_data: seed})
        if begin_word:
            word = begin_word
        else:
            word = prediction_to_word(predict, vocabularies)
        poem_ = ''
        generated = 0
        while word != end_token:
            poem_ += word
            generated += 1
            if generated > 24:  # hard cap on poem length
                break
            step_input = np.array([[word_int_map[word]]])
            predict, last_state = sess.run(
                fetches,
                feed_dict={input_data: step_input,
                           end_points['initial_state']: last_state})
            word = prediction_to_word(predict, vocabularies)
        return poem_
def pretty_print_poem(poem_):
    """Print the poem one sentence per line, restoring the '。' separator.

    Fragments that are empty or 10 characters or shorter (typically the
    truncated tail) are dropped.
    """
    for sentence in poem_.split('。'):
        if sentence != '' and len(sentence) > 10:
            print(sentence + '。')
# Mode switch: True runs training, False runs interactive poem generation.
is_training = True
def main():
    """Entry point: train the model, or interactively generate a poem."""
    if not is_training:
        begin_char = input('## (输入 quit 退出)请输入第一个字 please input the first character: ')
        if begin_char == 'quit':
            exit()
        pretty_print_poem(poem_=gen_poem(begin_char))
    else:
        run_training()


if __name__ == '__main__':
    main()
# NOTE: blog-scrape footer preserved as a comment so the file parses:
# "tensorflow 实现 RNN 实验" (TensorFlow RNN experiment), published 2024-05-05.