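RNNs have memory: once trained, they can learn the underlying patterns in sequential data and, based on those patterns, randomly generate new sequences.
In this example, the names of existing US cities are fed into an RNN; after learning the patterns behind them, the RNN randomly generates new city names. The figure below shows the team logos of Major League Baseball; many US cities have their own baseball team, and the team is usually named after the city.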
1. Data cleaning and featurization
The dataset file is located at ../data/US_Cities.txt.
The data processing code is as follows:
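from tflearn.data_utils import string_to_semi_redundant_sequences
path = "../data/US_Cities.txt"
maxlen = 20  # length of each input window, in characters
# Read the whole file into one string and close the file afterwards
with open(path, "r") as f:
    file_lines = f.read()
# Cut the text into overlapping maxlen-character windows, one every 3 characters
X, Y, char_idx = \
    string_to_semi_redundant_sequences(file_lines, seq_maxlen=maxlen, redun_step=3)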
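To make the windowing step concrete, here is a minimal pure-Python sketch of the idea behind semi-redundant sequences: a window of `seq_maxlen` characters slides over the text in steps of `redun_step`, and the character right after each window becomes the prediction target. The helper names `semi_redundant_windows` and `one_hot` and the three-city sample string are made up for illustration; this is a simplified stand-in, not tflearn's own implementation, which returns the windows and targets as one-hot arrays.

```python
def semi_redundant_windows(text, seq_maxlen=20, redun_step=3):
    """Slide a seq_maxlen-character window over text, redun_step characters at a time.

    Returns the windows, the character following each window, and a
    character-to-index dictionary built from the text.
    """
    char_idx = {ch: i for i, ch in enumerate(sorted(set(text)))}
    windows, next_chars = [], []
    for start in range(0, len(text) - seq_maxlen, redun_step):
        windows.append(text[start:start + seq_maxlen])
        next_chars.append(text[start + seq_maxlen])
    return windows, next_chars, char_idx


def one_hot(ch, char_idx):
    """Encode one character as a list with a single 1 at its dictionary index."""
    vec = [0] * len(char_idx)
    vec[char_idx[ch]] = 1
    return vec


# Hypothetical miniature dataset: three city names, one per line
sample = "Austin\nBoston\nDallas\n"
windows, next_chars, char_idx = semi_redundant_windows(sample, seq_maxlen=5, redun_step=3)
first_target = one_hot(next_chars[0], char_idx)
print(windows[0], repr(next_chars[0]), len(windows))
```

With `seq_maxlen=5` and `redun_step=3`, consecutive windows overlap by two characters; with the settings used in this example (20 and 3) they would overlap by 17.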
2. Building the RNN training model
The model is built with LSTM layers; its design is shown below.
The code is as follows:
g = tflearn.input_data(shape=[None, maxlen, len(char_idx)])
g = tflearn.lstm(g, 512, return_seq=True)
g = tflearn.dropout(g, 0.5)
g = tflearn.lstm(g, 512)
g = tflearn.dropout(g, 0.5)
g = tflearn.fully_connected(g, len(char_idx), activation='softmax')
g = tflearn.regression(g, optimizer='adam', loss='categorical_crossentropy',
                       learning_rate=0.001)
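The `categorical_crossentropy` loss used above can be illustrated for a single prediction. The sketch below assumes a one-hot target and natural logarithms, in which case the loss reduces to the negative log of the probability assigned to the correct character. The function name and the example probabilities are hypothetical.

```python
import math


def categorical_crossentropy(probs, target_index):
    """Loss for one prediction against a one-hot target: -log(p[target])."""
    return -math.log(probs[target_index])


# Hypothetical softmax output over a 3-character vocabulary
probs = [0.1, 0.7, 0.2]
confident_loss = categorical_crossentropy(probs, 1)  # correct character got p = 0.7
uncertain_loss = categorical_crossentropy(probs, 0)  # correct character got p = 0.1

# exp(loss) can be read as the effective number of equally likely candidates
effective_choices = math.exp(2.0)
print(round(confident_loss, 3), round(uncertain_loss, 3), round(effective_choices, 2))
```

Under these assumptions, an average loss near 2.0 means the model is about as uncertain as a uniform choice among roughly 7 characters.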
3. Instantiating the sequence generator
m = tflearn.SequenceGenerator(g, dictionary=char_idx,
                              seq_maxlen=maxlen,
                              clip_gradients=5.0,
                              checkpoint_path='model_us_cities')
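As a rough picture of what a character-level sequence generator does at generation time, the sketch below extends a seed one character at a time, always feeding the most recent `seq_maxlen` characters back in as the next input. `generate` and `toy_predict_next` are hypothetical names; the toy predictor simply cycles through "abc" and stands in for the trained network, so this shows the loop structure only, not tflearn's actual code.

```python
def generate(predict_next, seed, n_chars, seq_maxlen):
    """Extend seed by n_chars, feeding the last seq_maxlen characters back in each step."""
    text = seed
    for _ in range(n_chars):
        window = text[-seq_maxlen:]    # the model only ever sees the latest window
        text += predict_next(window)   # append the predicted character
    return text


def toy_predict_next(window):
    """Hypothetical stand-in for the trained network: cycles a -> b -> c -> a."""
    cycle = "abc"
    return cycle[(cycle.index(window[-1]) + 1) % len(cycle)]


result = generate(toy_predict_next, seed="ab", n_chars=4, seq_maxlen=2)
print(result)
```

The returned string is the seed followed by the newly generated characters, so its length is the seed length plus `n_chars`.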
4. Verifying the results
Using a random seed sequence drawn from the training text, the RNN model randomly generates city names:
for i in range(40):
    seed = random_sequence_from_string(file_lines, maxlen)
    m.fit(X, Y, validation_set=0.1, batch_size=128,
          n_epoch=1, run_id='us_cities')
    print("-- TESTING...")
    print("-- Test with temperature of 1.2 --")
    print(m.generate(30, temperature=1.2, seq_seed=seed))
    print("-- Test with temperature of 1.0 --")
    print(m.generate(30, temperature=1.0, seq_seed=seed))
    print("-- Test with temperature of 0.5 --")
    print(m.generate(30, temperature=0.5, seq_seed=seed))
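The `temperature` argument controls how adventurous the sampling is. A common way to implement it, assumed here rather than taken from the article, is to divide the log-probabilities by the temperature and renormalize before sampling: values below 1 sharpen the distribution toward the most likely character, values above 1 flatten it. The function names and the example probabilities are illustrative, and the probabilities are assumed to be strictly positive.

```python
import math
import random


def apply_temperature(probs, temperature=1.0):
    """Rescale a probability distribution: softmax(log(p) / temperature)."""
    logits = [math.log(p) / temperature for p in probs]
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(probs, temperature=1.0, rng=random):
    """Draw one index at random from the temperature-adjusted distribution."""
    scaled = apply_temperature(probs, temperature)
    return rng.choices(range(len(scaled)), weights=scaled, k=1)[0]


# Hypothetical next-character probabilities for a 3-character vocabulary
probs = [0.6, 0.3, 0.1]
cold = apply_temperature(probs, 0.5)       # sharper than the original
unchanged = apply_temperature(probs, 1.0)  # identical to the original
hot = apply_temperature(probs, 1.2)        # flatter than the original
picked = sample_index(probs, temperature=0.5)
print([round(p, 3) for p in cold], [round(p, 3) for p in hot], picked)
```

Lower temperatures make the generated names more conservative, while higher temperatures produce more varied but less name-like output.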
5. Results
There are 62,106 training samples and 6,901 validation samples:
Training samples: 62106
Validation samples: 6901
Generated city names
......
| Adam | epoch: 007 | loss: 2.02764 -- iter: 61824/62106
Training Step: 3400 | total loss: 2.00343 | time: 316.567s
| Adam | epoch: 007 | loss: 2.00343 -- iter: 61952/62106
Training Step: 3401 | total loss: 2.00492 | time: 317.206s
| Adam | epoch: 007 | loss: 2.00492 -- iter: 62080/62106
Training Step: 3402 | total loss: 2.03897 | time: 330.817s
| Adam | epoch: 007 | loss: 2.03897 | val_loss: 1.96934 -- iter: 62106/62106
--
-- TESTING...
-- Test with temperature of 1.2 --
vajo
Navarino
Navarrax-Pankime
Waatlophi
Laperting
-- Test with temperature of 1.0 --
vajo
Navarino
Navarrey Halls
Eals Carton
Banon
ar
-- Test with temperature of 0.5 --
vajo
Navarino
Navarroe
Sauet Mallan
Mauntan
Corenn
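The counts in the log above can be cross-checked with a little arithmetic. The sketch below assumes that the split point is `int(total * (1 - validation_set))` and that each epoch runs `ceil(train / batch_size)` steps; these formulas are inferred from the numbers in the log, not taken from tflearn's documentation.

```python
import math

# Values reported in the log and used in the fit() call
train_samples = 62106
validation_samples = 6901
batch_size = 128
validation_set = 0.1

total = train_samples + validation_samples

# Assumed split rule: the first 90% of the samples go to training
split_at = int(total * (1 - validation_set))

# Assumed step rule: one step per batch, the last batch may be smaller
steps_per_epoch = math.ceil(train_samples / batch_size)
steps_after_7_epochs = 7 * steps_per_epoch

print(total, split_at, steps_per_epoch, steps_after_7_epochs)
```

The results, 62106 training samples and 3402 steps after seven epochs, agree with the `Training samples` line and with `Training Step: 3402` at the end of epoch 007.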
6. Complete source code
from __future__ import absolute_import, division, print_function
import tflearn
# Import only the two helpers this script actually uses
from tflearn.data_utils import (random_sequence_from_string,
                                string_to_semi_redundant_sequences)
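path = "../data/US_Cities.txt"
maxlen = 20  # length of each input window, in characters
# Read the whole file into one string and close the file afterwards
with open(path, "r") as f:
    file_lines = f.read()
# Cut the text into overlapping maxlen-character windows, one every 3 characters
X, Y, char_idx = \
    string_to_semi_redundant_sequences(file_lines, seq_maxlen=maxlen, redun_step=3)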
g = tflearn.input_data(shape=[None, maxlen, len(char_idx)])
g = tflearn.lstm(g, 512, return_seq=True)
g = tflearn.dropout(g, 0.5)
g = tflearn.lstm(g, 512)
g = tflearn.dropout(g, 0.5)
g = tflearn.fully_connected(g, len(char_idx), activation='softmax')
g = tflearn.regression(g, optimizer='adam', loss='categorical_crossentropy',
                       learning_rate=0.001)
m = tflearn.SequenceGenerator(g, dictionary=char_idx,
                              seq_maxlen=maxlen,
                              clip_gradients=5.0,
                              checkpoint_path='model_us_cities')
for i in range(40):
    seed = random_sequence_from_string(file_lines, maxlen)
    m.fit(X, Y, validation_set=0.1, batch_size=128,
          n_epoch=1, run_id='us_cities')
    print("-- TESTING...")
    print("-- Test with temperature of 1.2 --")
    print(m.generate(30, temperature=1.2, seq_seed=seed))
    print("-- Test with temperature of 1.0 --")
    print(m.generate(30, temperature=1.0, seq_seed=seed))
    print("-- Test with temperature of 0.5 --")
    print(m.generate(30, temperature=0.5, seq_seed=seed))