20200818 -
引言
前面文章中,介绍了字符级别的文本生成《LSTM生成文本(字符级别),在字符级别的生成过程中,利用滑动窗口的形式来持续生成文本。本文中介绍看到的另外一篇基于单词的生成形式。
LSTM文本生成
本篇文章中,主要参考了kaggle上的一篇文章[1],在模型中,使用了embedding层,然后输入其实是句子。但是感觉他的代码部分并不是非常友好,也可能是我对模型的使用有些忘记了。
数据预处理
def generate_padded_sequences(input_sequences):
max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
predictors, label = input_sequences[:,:-1],input_sequences[:,-1]
label = ku.to_categorical(label, num_classes=total_words)
return predictors, label, max_sequence_len
predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences)
从上文中,可以看出,其使用的方式跟前面字符级别文章的一样,也是输入一段字符,然后最后一个单词作为输出。
模型
def create_model(max_sequence_len, total_words):
input_len = max_sequence_len - 1
model = Sequential()
# Add Input Embedding Layer
model.add(Embedding(total_words, 10, input_length=input_len))
# Add Hidden Layer 1 - LSTM Layer
model.add(LSTM(100))
model.add(Dropout(0.1))
# Add Output Layer
model.add(Dense(total_words, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
return model
model = create_model(max_sequence_len, total_words)
model.summary()
基本上理解是没有问题的,一个embedding层,然后就是LSTM,最后一个多分类。
预测部分
def generate_text(seed_text, next_words, model, max_sequence_len):
for _ in range(next_words):
token_list = tokenizer.texts_to_sequences([seed_text])[0]
token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
predicted = model.predict_classes(token_list, verbose=0)
output_word = ""
for word,index in tokenizer.word_index.items():
if index == predicted:
output_word = word
break
seed_text += " "+output_word
return seed_text.title()
print (generate_text("united states", 5, model, max_sequence_len))
print (generate_text("preident trump", 4, model, max_sequence_len))
print (generate_text("donald trump", 4, model, max_sequence_len))
print (generate_text("india and china", 4, model, max_sequence_len))
print (generate_text("new york", 4, model, max_sequence_len))
print (generate_text("science and technology", 5, model, max_sequence_len))
最后是一个预测的过程,也是将生成的文本重新添加到之前的种子中。
后记
感觉这篇文章的内容和之前的字符级别在数据处理(准备输入输出上)没有本质的区别,大致还是利用滑动窗口的形式来预测,只不过是使用了单词级别。
还是要再看看更多的内容,那么后续的模型改进中,是对模型的机构进行了改进,还是从数据的角度呢?