Tensorflow lstm实现的小说撰写预测

最近,在研究深度学习方面的知识,结合Tensorflow,完成了基于lstm的小说预测程序demo。

lstm是改进的RNN,具有长期记忆功能,相对于RNN,增加了多个门来控制输入与输出。原理方面的知识网上很多,在此,我只是将我短暂学习的tensorflow写一个预测小说的demo,如果有错误,还望大家指出。

1、将小说进行分词,去除空格,建立词汇表与id的字典,生成初始输入模型的x与y

def readfile(file_path):
    f = codecs.open(file_path, 'r', 'utf-8')
    alltext = f.read()
    alltext = re.sub(r'\s','', alltext)
    seglist = list(jieba.cut(alltext, cut_all = False))
    return seglist
    
def _build_vocab(filename):
    data = readfile(filename)
    counter = collections.Counter(data)
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))


    words, _ = list(zip(*count_pairs))
    word_to_id = dict(zip(words, range(len(words))))
    id_to_word = dict(zip(range(len(words)),words))
    dataids = []
    for w in data:
        dataids.append(word_to_id[w])
    return word_to_id, id_to_word,dataids


def dataproducer(batch_size, num_steps):
    word_to_id, id_to_word, data = _build_vocab('F:\\ml\\code\\lstm\\1.txt')
    datalen = len(data)
    batchlen = datalen//batch_size
    epcho_size = (batchlen - 1)//num_steps


    data = tf.reshape(data[0: batchlen*batch_size], [batch_size,batchlen])
    i = tf.train.range_input_producer(epcho_size, shuffle=False).dequeue()
    x = tf.slice(data, [0,i*num_steps],[batch_size, num_steps])
    y = tf.slice(data, [0,i*num_steps+1],[batch_size, num_steps])
    x.set_shape([batch_size, num_steps])
    y.set_shape([batch_size, num_steps])
    return x,y,id_to_word

2、建立lstm模型:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias = 0.5)
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob = keep_prob)
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell], num_layers)

3、根据训练数据输出误差反向调整模型

with tf.variable_scope("Model", reuse = None, initializer = initializer):#tensorflow主要通过变量空间来实现共享变量
    with tf.variable_scope("r", reuse = None, initializer = initializer):
        softmax_w = tf.get_variable('softmax_w', [size, vocab_size])
        softmax_b = tf.get_variable('softmax_b', [vocab_size])
    with tf.variable_scope("RNN", reuse = None, initializer = initializer):
        for time_step in range(num_steps):
            if time_step > 0: tf.get_variable_scope().reuse_variables()
            (cell_output, state) = cell(inputs[:, time_step, :], state,)
            outputs.append(cell_output)
            
        output = tf.reshape(outputs, [-1,size])
        
        logits = tf.matmul(output, softmax_w) + softmax_b
        loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [tf.reshape(targets,[-1])], [tf.ones([batch_size*num_steps])])
        
        global_step = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
        10.0, global_step, 5000, 0.1, staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        gradients, v = zip(*optimizer.compute_gradients(loss))
        gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
        optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)

4、预测新一轮输出

teststate = test_initial_state
        (celloutput,teststate)= cell(test_inputs, teststate)
        partial_logits = tf.matmul(celloutput, softmax_w) + softmax_b
        partial_logits = tf.nn.softmax(partial_logits)

5、根据之前建立的操作,运行tensorflow会话

sv = tf.train.Supervisor(logdir=None)
with sv.managed_session() as session:
    costs = 0
    iters = 0
    for i in range(1000):
        _,l= session.run([optimizer, cost])
        costs += l
        iters +=num_steps
        perplextity = np.exp(costs / iters)
        if i%20 == 0:
            print(perplextity)
        if i%100 == 0:
            p = random_distribution()
            b = sample(p)
            sentence = id_to_word[b[0]]
            for j in range(200):
                test_output = session.run(partial_logits, feed_dict={test_input:b})
                b = sample(test_output)
                sentence += id_to_word[b[0]]
            print(sentence)    

其中,使用sv.managed_session()后,在此会话间,将不能修改graph。如果采用普通的session,程序将会阻塞于session.run(),对于这个问题,我还是很疑惑,希望理解的人帮忙解答下。

代码地址位于https://github.com/summersunshine1/datamining/tree/master/lstm,运行时只需将readdata中文件路径修改即可。作为深度学习的入门小白,希望大家多多指点。

运行结果如下:



LSTM (Long Short-Term Memory) 是一种用于处理序列数据的深度学习模型,它可以用于生成小说等文本数据。以下是一个简单的 LSTM 小说生成的示例代码,你可以参考一下: ```python from keras.models import Sequential from keras.layers import LSTM, Dense, Activation import numpy as np # 定义训练数据 text = "小明喜欢玩游戏,他经常熬夜玩游戏,有一天,他玩了一个通宵,第二天早上醒来,发现自己穿越到了游戏世界……" chars = list(set(text)) # 获取所有唯一字符 char_to_int = dict((c, i) for i, c in enumerate(chars)) # 字符到整数的映射 int_to_char = dict((i, c) for i, c in enumerate(chars)) # 整数到字符的映射 seq_length = 10 # 序列长度 dataX = [] dataY = [] for i in range(0, len(text)-seq_length): seq_in = text[i:i+seq_length] seq_out = text[i+seq_length] dataX.append([char_to_int[char] for char in seq_in]) dataY.append(char_to_int[seq_out]) n_patterns = len(dataX) # 将训练数据转换为 LSTM 可接受的格式 X = np.reshape(dataX, (n_patterns, seq_length, 1)) X = X / float(len(chars)) y = np_utils.to_categorical(dataY) # 定义 LSTM 模型 model = Sequential() model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True)) model.add(Dropout(0.2)) model.add(LSTM(256)) model.add(Dropout(0.2)) model.add(Dense(y.shape[1], activation='softmax')) # 编译模型 model.compile(loss='categorical_crossentropy', optimizer='adam') # 训练模型 model.fit(X, y, epochs=20, batch_size=64) # 生成新文本 start = np.random.randint(0, len(dataX)-1) pattern = dataX[start] print("Seed:") print("\"", ''.join([int_to_char[value] for value in pattern]), "\"") for i in range(100): x = np.reshape(pattern, (1, len(pattern), 1)) x = x / float(len(chars)) prediction = model.predict(x, verbose=0) index = np.argmax(prediction) result = int_to_char[index] seq_in = [int_to_char[value] for value in pattern] print(result) pattern.append(index) pattern = pattern[1:len(pattern)] ``` 这段代码会训练一个基于 LSTM 的模型,然后根据给定的初始文本生成新的文本。当然,这只是一个简单的示例,你可以通过修改模型结构、调整超参数等手段提高生成文本的质量。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值