Implementing an RNN in TensorFlow: a poem-generation experiment
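
The script below is a complete TF 1.x example: it trains a character-level LSTM language model on a pre-indexed poem corpus, checkpoints the model every ten epochs, and then samples new poems one character at a time.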

import numpy as np
import tensorflow as tf
import os
import json

start_token = 'B'  # marks the beginning of a poem
end_token = 'E'    # marks the end of a poem
model_dir = './model'  # checkpoints are written here

def process_poems():
    """Load the preprocessed corpus: index vectors plus both vocabulary maps."""
    poems_vector = np.load('poems_index.npy', allow_pickle=True)
    with open('word2int.json', 'r') as f:
        word_int_map = json.load(f)   # character -> index
    with open('int2word.json', 'r') as f:
        int_word_map = json.load(f)   # index (as str, from JSON) -> character
    return poems_vector, word_int_map, int_word_map
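
# Hedged sketch of the preprocessing step this script relies on. The original
# post does not include it; the helper below is an assumption that simply
# produces files in the expected format (one index sequence per poem, plus
# the two vocabulary maps, with string keys in int2word.json).
def build_poem_files(poems):
    # poems: a list of raw poem strings
    poems = [start_token + p + end_token for p in poems]
    words = sorted(set(''.join(poems)) | {' '})  # ' ' doubles as the pad char
    word2int = {w: i for i, w in enumerate(words)}
    int2word = {str(i): w for w, i in word2int.items()}
    vectors = np.array([[word2int[w] for w in p] for p in poems], dtype=object)
    np.save('poems_index.npy', vectors)
    with open('word2int.json', 'w') as f:
        json.dump(word2int, f)
    with open('int2word.json', 'w') as f:
        json.dump(int2word, f)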

def rnn_model(model,
              input_data,
              output_data,
              vocab_size,
              rnn_size=128,
              batch_size=64,
              learning_rate=0.01):
    end_points = {}
    # one extra embedding row so that index vocab_size (padding/OOV) is valid
    embedding = tf.get_variable(
        'embedding',
        initializer=tf.random_uniform([vocab_size + 1, rnn_size], -1.0, 1.0))

    inputs = tf.nn.embedding_lookup(embedding, input_data)
    if model == 'rnn':
        cell = tf.contrib.rnn.BasicRNNCell(rnn_size)
    elif model == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(rnn_size, state_is_tuple=True)
    elif model == 'gru':
        cell = tf.contrib.rnn.GRUCell(rnn_size)
    else:
        raise ValueError("model must be one of 'rnn', 'lstm', 'gru'")

    if output_data is not None:
        # training: one all-zero initial state per example in the batch
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        # generation: a single sequence is fed one step at a time
        initial_state = cell.zero_state(1, tf.float32)
    outputs, last_state = tf.nn.dynamic_rnn(cell,
                                            inputs,
                                            initial_state=initial_state)
    # flatten batch and time dimensions: [batch * time, rnn_size]
    outputs = tf.reshape(outputs, [-1, rnn_size])
    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(outputs, weights), bias=bias)

    if output_data is not None:
        # training branch: cross-entropy against one-hot next-character targets
        labels = tf.one_hot(tf.reshape(output_data, [-1]),
                            depth=vocab_size + 1)
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits)
        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
        end_points['initial_state'] = initial_state
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['last_state'] = last_state
    else:
        # generation branch: expose the softmax distribution for sampling
        prediction = tf.nn.softmax(logits)
        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction
    
    return end_points
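
# Quick smoke test (commented out; the sizes are illustrative assumptions):
# the training graph should build and expose a scalar loss tensor.
# inp = tf.placeholder(tf.int32, [64, None])
# tgt = tf.placeholder(tf.int32, [64, None])
# ep = rnn_model('lstm', inp, tgt, vocab_size=6000)
# print(ep['total_loss'])  # a Tensor with shape ()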

def run_training():
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    poems_vector, word_to_int, vocabularies = process_poems()
    # input pipeline: shuffle, repeat, pad each batch to its longest poem,
    # then split every padded batch into (input, target) shifted by one step
    ds = tf.data.Dataset.from_generator(lambda: [ins for ins in poems_vector],
                                        tf.int32,
                                        tf.TensorShape([None]))
    ds = ds.shuffle(buffer_size=poems_vector.shape[0])
    ds = ds.repeat()
    ds = ds.padded_batch(64,
                         padded_shapes=tf.TensorShape([None]),
                         padding_values=word_to_int[' '])
    ds = ds.map(lambda x: (x[:, :-1], x[:, 1:]))
    iterator = ds.make_initializable_iterator()
    Xs, Ys = iterator.get_next()
    # batches are fetched into numpy first and re-fed through placeholders,
    # which keeps rnn_model() independent of the dataset pipeline
    input_data = tf.placeholder(tf.int32, [64, None])
    output_targets = tf.placeholder(tf.int32, [64, None])
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies))
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(iterator.initializer)
        for epoch in range(50):
            # one pass over the corpus: the number of poems, not the vocab size
            for batch in range(poems_vector.shape[0] // 64):
                batch_xs, batch_ys = sess.run([Xs, Ys])
                loss, _ = sess.run(
                    [end_points['total_loss'], end_points['train_op']],
                    feed_dict={
                        input_data: batch_xs,
                        output_targets: batch_ys
                    })
                print('Epoch: %d, batch: %d, training loss: %.6f' %
                      (epoch, batch, loss))
            print('Epoch: %d, training loss: %.6f' % (epoch, loss))
            if epoch % 10 == 0:
                saver.save(sess,
                           os.path.join(model_dir, 'poems'),
                           global_step=epoch)
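
# Portability note (an assumption, not tested in the original post): this is
# TF 1.x graph code, and tf.contrib was removed in TF 2.x. On TF >= 1.4 the
# same cells exist outside contrib, e.g.
#   cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=True)
# and under TF 2.x the script can usually run via the v1 compatibility layer:
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()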

def prediction_to_word(predict, vocabs):
    predict = predict[0]
    predict /= np.sum(predict)  # renormalize against floating-point drift
    sample = np.random.choice(np.arange(len(predict)), p=predict)
    if sample >= len(vocabs):   # index vocab_size has no vocabulary entry
        return ' '
    else:
        return vocabs[str(sample)]
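
# Optional variant (not in the original post): temperature sampling. Dividing
# the log-probabilities by a temperature < 1 sharpens the distribution and
# > 1 flattens it, trading coherence against diversity.
def prediction_to_word_with_temperature(predict, vocabs, temperature=0.8):
    p = np.log(predict[0] + 1e-10) / temperature
    p = np.exp(p - np.max(p))   # softmax with numerical stabilization
    p /= np.sum(p)
    sample = np.random.choice(np.arange(len(p)), p=p)
    return vocabs[str(sample)] if sample < len(vocabs) else ' '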

def gen_poem(begin_word):
    batch_size = 1
    poems_vector, word_int_map, vocabularies = process_poems()
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=None,
                           vocab_size=len(vocabularies),
                           learning_rate=0.0002)
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)
        # prime the network with the start token, then let the user-supplied
        # first character (if any) override the first sampled one
        x = np.array([list(map(word_int_map.get, start_token))])
        [predict, last_state] = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})
        word = begin_word or prediction_to_word(predict, vocabularies)
        poem_ = ''
        i = 0
        while word != end_token:
            poem_ += word
            i += 1
            if i > 24:  # hard cap on generated length
                break
            x = np.array([[word_int_map[word]]])
            # feed one character plus the previous state to get the next one
            [predict, last_state] = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={input_data: x,
                           end_points['initial_state']: last_state})
            word = prediction_to_word(predict, vocabularies)
        return poem_
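
# Usage sketch (assumes a trained checkpoint under ./model; the starting
# character is an arbitrary example):
#   poem = gen_poem('春')
#   pretty_print_poem(poem)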

def pretty_print_poem(poem_):
    poem_sentences = poem_.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 10:  # drop truncated trailing fragments
            print(s + '。')

is_training = True  # set to False to sample poems from a trained checkpoint

def main():
    if is_training:
        run_training()
    else:
        begin_char = input('## Please input the first character (type quit to exit): ')
        if begin_char == 'quit':
            exit()
        poem = gen_poem(begin_char)
        pretty_print_poem(poem_=poem)

if __name__ == '__main__':
    main()