Source code from http://deeplearning.net/tutorial/lstm.html
I recently started learning about LSTMs, and this official tutorial code is probably the best material for understanding them. To keep myself from forgetting, I am writing up this walkthrough.
First, a quick recap of what an LSTM is: simply put, it is a recurrent network over time, designed to fix the weakness of the plain RNN, whose gradients explode or vanish and which therefore cannot remember information across long time gaps.
http://colah.github.io/posts/2015-08-Understanding-LSTMs/ is a very accessible introduction to LSTMs.
This post starts from the official tutorial code and walks through it step by step, following the order in which the program runs. Anyway, let's dive in.
if __name__ == '__main__':
    # See function train for all possible parameters and their definitions.
    train_lstm(
        max_epochs=100,
        test_size=500,
    )
The program starts by setting two arguments, max_epochs and test_size; all other parameters keep their default values.
Jumping to def train_lstm: the function signature begins with a long list of default arguments, each with an explanatory comment, so they should be easy to follow.
One line worth calling out is model_options = locals().copy():

model_options = locals().copy()
print("model options", model_options)

At the top of the function, locals() contains exactly the function's parameters, so copying it collects every argument (defaults included) into a dictionary.
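As a minimal sketch of this trick (the function name and parameters below are made up for illustration, not from the tutorial), locals().copy() at the start of a function snapshots its keyword arguments into a plain dict:

```python
def train_toy(max_epochs=100, lrate=0.01, dim_proj=128):
    # At the very top of the function, locals() holds exactly the
    # parameters, so copying it yields a dict of all options.
    options = locals().copy()
    return options

opts = train_toy(max_epochs=5)
# opts now contains all three parameters, with max_epochs overridden to 5
```

The copy matters: locals() would otherwise keep tracking new local variables defined later in the function.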
train, valid, test = load_data(n_words=n_words, valid_portion=0.05,
                               maxlen=maxlen)
load_data lives in imdb.py:
def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None,
              sort_by_len=True):
    # Download the dataset if it is not present locally.
    path = get_dataset_file(
        path, "imdb.pkl",
        "http://www.iro.umontreal.ca/~lisa/deep/data/imdb.pkl")

    if path.endswith(".gz"):
        f = gzip.open(path, 'rb')
    else:
        f = open(path, 'rb')

    train_set = pickle.load(f)
    test_set = pickle.load(f)
    f.close()

    # Keep only the training sequences shorter than maxlen.
    if maxlen:
        new_train_set_x = []
        new_train_set_y = []
        for x, y in zip(train_set[0], train_set[1]):
            if len(x) < maxlen:
                new_train_set_x.append(x)
                new_train_set_y.append(y)
        train_set = (new_train_set_x, new_train_set_y)
        del new_train_set_x, new_train_set_y

    # split training set into validation set
    train_set_x, train_set_y = train_set
    n_samples = len(train_set_x)
    sidx = numpy.random.permutation(n_samples)
    n_train = int(numpy.round(n_samples * (1. - valid_portion)))
    valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
    valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
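The maxlen filter and the permutation-based split above can be reproduced on toy data (the sequences and labels below are made up for illustration; only numpy is assumed):

```python
import numpy

numpy.random.seed(0)  # fixed seed so the shuffle is reproducible

# Toy "dataset": 6 variable-length sequences with binary labels.
train_set = ([[1, 2], [3, 4, 5, 6, 7], [8], [9, 10, 11], [12, 13], [14]],
             [0, 1, 0, 1, 0, 1])

# Drop sequences of length >= maxlen, exactly as the tutorial code does.
maxlen = 4
pairs = [(x, y) for x, y in zip(*train_set) if len(x) < maxlen]
train_set = ([x for x, _ in pairs], [y for _, y in pairs])

# Shuffle indices and hold out valid_portion of the samples for validation.
valid_portion = 0.25
n_samples = len(train_set[0])
sidx = numpy.random.permutation(n_samples)
n_train = int(numpy.round(n_samples * (1. - valid_portion)))
train_set_x = [train_set[0][s] for s in sidx[:n_train]]
valid_set_x = [train_set[0][s] for s in sidx[n_train:]]
print(len(train_set_x), len(valid_set_x))  # 4 1
```

Note that the split happens after the maxlen filter, so valid_portion applies to the filtered sample count, and permutation of indices (rather than of the lists themselves) keeps each x aligned with its label y.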