This article builds on the original by adding some comments and run results, and modifying a small amount of code.
1. Introduction
LSTM (Long Short-Term Memory) is a special kind of recurrent neural network that performs much better than a standard RNN on many tasks.
For an introduction to LSTM, see references 1 and 2. This article focuses on using an LSTM to build a classifier.
2. How to use LSTM in Keras
This article mainly tests using Word Embeddings in Keras and performing classification with them. The code is adapted from an example in the official Keras documentation. IPython code link
2.1 The Word Embeddings dataset
Stanford's GloVe is used as the word-vector set: download the pre-trained vector file, then look each word up directly in the dictionary to get its vector. GloVe dataset download. The text test data is 20_newsgroup:
This data set is a collection of 20,000 messages, collected from 20 different netnews newsgroups. One thousand messages from each of the twenty newsgroups were chosen at random and partitioned by newsgroup name. The list of newsgroups from which the messages were chosen is as follows:
alt.atheism
talk.politics.guns
talk.politics.mideast
talk.politics.misc
talk.religion.misc
soc.religion.christian
comp.sys.ibm.pc.hardware
comp.graphics
comp.os.ms-windows.misc
comp.sys.mac.hardware
comp.windows.x
rec.autos
rec.motorcycles
rec.sport.baseball
rec.sport.hockey
sci.crypt
sci.electronics
sci.space
sci.med
misc.forsale
We split the messages into 20 classes by labeling them; each newsgroup name is mapped to a numeric label.
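The name-to-label mapping can be sketched as follows (a minimal illustration using a hypothetical subset of the newsgroup names; the actual loop in section 2.2 builds the full mapping from the directory listing):

```python
# Minimal sketch: assign each newsgroup name a numeric label id,
# in the same order the directory walk in 2.2 would (sorted names).
names = ["sci.space", "alt.atheism", "comp.graphics"]  # subset for illustration

labels_index = {}
for name in sorted(names):
    labels_index[name] = len(labels_index)  # next free id

print(labels_index)
# {'alt.atheism': 0, 'comp.graphics': 1, 'sci.space': 2}
```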
The modules we need:
import numpy as np
import os
import sys
import random
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils.np_utils import to_categorical
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense, Activation
2.2 Data preprocessing
This part sets the training parameters and reads in the pre-trained GloVe word-vector file. The texts are read into a list, with each document stored as one str, giving a [str].
BASE_DIR = '/home/lich/Workspace/Learning'
GLOVE_DIR = BASE_DIR + '/glove.6B/'
TEXT_DATA_DIR = BASE_DIR + '/20_newsgroup/'
MAX_SEQUENCE_LENGTH = 1000
MAX_NB_WORDS = 20000
EMBEDDING_DIM = 100
VALIDATION_SPLIT = 0.2
batch_size = 32
# first, build index mapping words in the embeddings set
# to their embedding vector
embeddings_index = {}
f = open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'))
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()
print('Found %s word vectors.' % len(embeddings_index))
#Found 400000 word vectors.
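Not every word in a corpus appears in GloVe's 400,000-word vocabulary. A common pattern (a sketch, not part of the original code; `lookup` and the toy dictionary here are hypothetical) is to use `dict.get` and fall back to a zero vector for out-of-vocabulary words:

```python
import numpy as np

EMBEDDING_DIM = 4  # toy dimension for illustration; glove.6B.100d.txt uses 100

# toy stand-in for the embeddings_index built above
embeddings_index = {
    "hi": np.array([0.1444, 0.2398, 0.9669, 0.3163], dtype="float32"),
}

def lookup(word, dim=EMBEDDING_DIM):
    # dict.get returns None when the word is missing from the GloVe vocabulary;
    # such words are typically represented by an all-zeros embedding row
    vec = embeddings_index.get(word)
    return vec if vec is not None else np.zeros(dim, dtype="float32")

print(lookup("hi"))       # known word -> its GloVe coefficients
print(lookup("zzzqx"))    # unknown word -> zero vector of length EMBEDDING_DIM
```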
# second, prepare text samples and their labels
print('Processing text dataset')
texts = [] # list of text samples
labels_index = {} # dictionary mapping label name to numeric id
labels = [] # list of label ids
for name in sorted(os.listdir(TEXT_DATA_DIR)):
    path = os.path.join(TEXT_DATA_DIR, name)
    if os.path.isdir(path):
        label_id = len(labels_index)
        labels_index[name] = label_id
        for fname in sorted(os.listdir(path)):
            if fname.isdigit():
                fpath = os.path.join(path, fname)
                if sys.version_info < (3,):
                    f = open(fpath)
                else:
                    f = open(fpath, encoding='latin-1')
                texts.append(f.read())
                f.close()
                labels.append(label_id)
print('Found %s texts.' % len(texts))
#Found 19997 texts.
The contents of embeddings_index look like this:
embeddings_index['hi']
"""
array([ 0.1444 , 0.23978999, 0.96692997, 0.31628999, -0.36063999,
-0.87673998, 0.098512 , 0.31077999, 0.47929001, 0.27175 ,
0.30004999, -0.23732001, -0.31516999, 0.17925 , 0.61773002,
0.59820998, 0.49489 , 0.3423 , -0.078034 , 0.60211998,
0.18683 , 0.52069998, -0.12331 , 0.48313001, -0.24117 ,
0.59696001, 0.61078 , -0.84413999, 0.27660999, 0.068767 ,
-1.13880002, 0.089544 , 0.89841998, 0.53788 , 0.10841 ,
-0.10038 , 0.12921 , 0.11476 , -0.47400001, -0.80489999,
0.95999998, -0.36601999, -0.43019 , -0.39807999, -0.096782 ,
-0.71183997, -0.31494001, 0.82345998, 0.42179 , -0.692049