This code demonstrates text classification using pretrained word embeddings.
The dataset is 20 Newsgroups: about 18,000 news articles covering 20 topics. A neural network is trained so that, given a news article, it can identify which of the 20 categories it belongs to.
The code runs fine on Linux, but on Windows it raises:
Traceback (most recent call last):
File "pretrained_word_embeddings.py", line 43, in <module>
for line in f:
UnicodeDecodeError: 'gbk' codec can't decode byte 0x93 in position 5456: illegal multibyte sequence
To fix it, change
with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt')) as f:
to:
with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'), encoding='utf-8') as f:
First, the pretrained vector for each word is loaded into the dictionary embeddings_index.
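A minimal sketch of this loading step: each line of the GloVe file is a word followed by its vector components. The parsing is pulled into a helper so it can be shown on a small inline sample; the real script would pass the opened glove.6B.100d.txt file instead.

```python
import numpy as np

def load_glove(lines):
    """Parse GloVe-format lines ("word v1 v2 ...") into a dict
    mapping each word to its vector."""
    embeddings_index = {}
    for line in lines:
        values = line.split()
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs
    return embeddings_index

# With the real file (note the encoding fix discussed above):
# with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'), encoding='utf-8') as f:
#     embeddings_index = load_glove(f)

# Small inline sample for illustration:
sample = ["the 0.1 0.2 0.3", "atheism -0.5 0.0 0.7"]
embeddings_index = load_glove(sample)
print(len(embeddings_index))          # 2
print(embeddings_index['the'].shape)  # (3,)
```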
Next, the line
sequences = tokenizer.texts_to_sequences(texts)
converts text such as:
Archive-name: atheism/resources
Alt-atheism-archive-name: resources
Last-modified: 11 December 1992
Version: 1.0
Atheist Resources
Addresses of Atheist Organizations
USA
FREEDOM FROM RELIGION FOUNDATION
Darwin fish bumper stickers and assorted other atheist paraphernalia are
available from the Freedom From Religion Foundation in the US.
into a structure like the following:
[1237, 273, 1213, 1439, 1071, 1213, 1237, 273, 1439, 192, 2515, 348, 2964, 779, 332, 28, 45, 1628, 1439, 2516, 3, 1628, 2144, 780, 937, 29, 441, 2770, 8854, 4601, 7969, 11979, 5, 12806, 75, 1628, 19, 229, 29, 1, 937, 29, 441, 2770, 6, 1, 118, 558, 2, 90, 106, 482, 3979, 6602, 5375, 1871, 12260, 1632, 17687, 1828, 5101, 1828, 5101, 788, 1, 8854, 4601, 96, 4, 4601, 5455, 64, 1, 751, 563, 1716, 15, 71, 844, 24, 20, 1971, 5, 1, 389, 8854, 744, 1023, 1, 7762, 1300, 2912, 4601, 8, 73, 1698, 6, 1, 118, 558, 2, 1828, 5101, 16500, 13447, 73, 1261, 10982, 170, 66, 6, 1, 869, 2235, 2544, 534, 34, 79, 8854, 4601, 29, 6603, 3388, 264, 1505, 535, 49, 12, 343, 66, 60, 155, 2, 6603, 1043, 1, 427, 8, 73, 1698, 618, 4601, 417, 1628, 632, 11716, 4602, 814, 1628, 691, 3, 1, 467, 2163, 3, 2266, 7491, 5, 48, 15, 40, 135, 378, 8, 1, 467, 6359, 30, 101, 90, 1781, 5, 115, 101, 417, 1628, 632, 17061, 1448, 4317, 45, 860, 73, 1611, 2455, 3343, 467, 7491, 13132, 5814, 1301, 1781, 1, 467, 9477, 667, 11716, 323, 15, 1, 1074, 802, 332, 3, 1, 467, 558, 2, 417, 1628, 632, 90, 106, 482, 2030, 2408, 22, 13799, 853, 2030, 2408, 1871, 3793, 12524, 439, 3793, 13448, 691, 788, 691, 502, 1552, 11221, 116, 993, 558, 2, 2974, 996, 7674, 1184, 1346, 108, 828, 1871, 9478, 12807, 32, 7675, 460, 61, 110, 16, 3362, 22, 1950, 8, 691, 1711, 5622, 233, 1346, 1428, 4623, 1260, 12, 16501, 32, 1044, 7854, 564, 3955, 16501, 5, 1, 500, 3, 564, 27, 4602, 4, 9648, 2913, 10746, 558, 2, 7128, 97, 2456, 2420, 4623, 1260, 12, 16501, 90, 106, 482, 13133, 1346, 1428, 797, 2652, 632, 2366, 445, 3955, 681, 2477, 288, 1184,
To see why, run:
print("tokenizer.index_word[1237]:", tokenizer.index_word[1237])
print("tokenizer.index_word[273]:", tokenizer.index_word[273])
which prints:
tokenizer.index_word[1237]: archive
tokenizer.index_word[273]: name
That is, texts_to_sequences replaces each word with its corresponding index.
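A pure-Python illustration of this word-to-index mapping (the real Keras Tokenizer additionally lowercases, strips punctuation, and orders the vocabulary by word frequency when fit_on_texts is called; the two index entries below are just the ones verified by the print statements above):

```python
# Hypothetical vocabulary entries, matching the printed values above.
word_index = {'archive': 1237, 'name': 273}
index_word = {i: w for w, i in word_index.items()}

def texts_to_sequences(texts, word_index):
    """Replace each word with its integer index.
    Words missing from the vocabulary are dropped, matching the
    Tokenizer's default behaviour."""
    return [[word_index[w] for w in t.lower().split() if w in word_index]
            for t in texts]

print(texts_to_sequences(['Archive name'], word_index))  # [[1237, 273]]
print(index_word[1237])                                  # archive
```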
Finally, an Embedding layer is built from the pretrained embedding matrix embedding_matrix.
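A sketch of how embedding_matrix is assembled: row i holds the pretrained vector of the word with index i, and words with no pretrained vector keep an all-zeros row. The tiny embeddings_index / word_index below are hypothetical stand-ins; the real script uses EMBEDDING_DIM = 100 and a 20,000-word vocabulary.

```python
import numpy as np

EMBEDDING_DIM = 3   # 100 in the real script; 3 here for the toy example
num_words = 4       # min(MAX_NUM_WORDS, len(word_index) + 1) in the real script

# Hypothetical stand-ins for the real embeddings_index / word_index:
embeddings_index = {'the': np.array([0.1, 0.2, 0.3], dtype='float32')}
word_index = {'the': 1, 'cat': 2, 'sat': 3}

# Row i of embedding_matrix is the vector for word index i; words without
# a pretrained vector keep the all-zeros row.
embedding_matrix = np.zeros((num_words, EMBEDDING_DIM))
for word, i in word_index.items():
    vector = embeddings_index.get(word)
    if vector is not None:
        embedding_matrix[i] = vector

# The matrix then initialises a frozen Embedding layer, which is why the
# summary reports 2,000,000 (= 20000 * 100) non-trainable parameters:
# embedding_layer = Embedding(num_words, EMBEDDING_DIM,
#                             weights=[embedding_matrix],
#                             input_length=MAX_SEQUENCE_LENGTH,
#                             trainable=False)
```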
The network architecture is:
____________________________________________________________________________________________________
Layer (type) Output Shape Param #
====================================================================================================
input_1 (InputLayer) (None, 1000) 0
____________________________________________________________________________________________________
embedding_1 (Embedding) (None, 1000, 100) 2000000
____________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 996, 128) 64128
____________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D) (None, 199, 128) 0
____________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 195, 128) 82048
____________________________________________________________________________________________________
max_pooling1d_2 (MaxPooling1D) (None, 39, 128) 0
____________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 35, 128) 82048
____________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalMaxPooling1D) (None, 128) 0
____________________________________________________________________________________________________
dense_1 (Dense) (None, 128) 16512
____________________________________________________________________________________________________
dense_2 (Dense) (None, 20) 2580
====================================================================================================
Total params: 2,247,316
Trainable params: 247,316
Non-trainable params: 2,000,000
____________________________________________________________________________________________________
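The parameter counts in this summary can be reproduced by hand, with no Keras required: a Conv1D layer with k filters of width w over c input channels has w*c*k + k parameters, a Dense layer with n units over m inputs has m*n + n, and the frozen Embedding layer contributes vocabulary_size * embedding_dim non-trainable parameters.

```python
def conv1d_params(width, in_channels, filters):
    # Each filter spans `width` positions of `in_channels` channels,
    # plus one bias per filter.
    return width * in_channels * filters + filters

def dense_params(in_dim, units):
    # One weight per input per unit, plus one bias per unit.
    return in_dim * units + units

embedding = 20000 * 100                # 2,000,000 (non-trainable)
conv1 = conv1d_params(5, 100, 128)     # 64,128
conv2 = conv1d_params(5, 128, 128)     # 82,048
conv3 = conv1d_params(5, 128, 128)     # 82,048
dense1 = dense_params(128, 128)        # 16,512
dense2 = dense_params(128, 20)         # 2,580

trainable = conv1 + conv2 + conv3 + dense1 + dense2
print(trainable)               # 247316
print(embedding + trainable)   # 2247316
```

These match the summary's trainable (247,316), non-trainable (2,000,000), and total (2,247,316) counts exactly.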