A walkthrough of the Keras example pretrained_word_embeddings.py

This script demonstrates text classification using pretrained word vectors.

The dataset is 20 Newsgroups: about 18,000 newsgroup posts covering 20 topics. A neural network is trained so that, given a new post, it can identify which of the 20 topics the post belongs to.

 

The script runs fine for me on Linux, but on Windows it raises an error:

Traceback (most recent call last):
  File "pretrained_word_embeddings.py", line 43, in <module>
    for line in f:
UnicodeDecodeError: 'gbk' codec can't decode byte 0x93 in position 5456: illegal multibyte sequence

To fix this, change

with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt')) as f:

to:

with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'), encoding='utf-8') as f:

 

First, the script reads the pretrained vector for each word into the dictionary embeddings_index.
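That loading loop can be sketched as follows (the real script iterates over glove.6B.100d.txt; the two vectors below are made-up stand-ins so the snippet is self-contained):

```python
# Sketch of how embeddings_index is built. In the real script the loop
# runs over the opened glove.6B.100d.txt file; here an in-memory sample
# with fake 3-d vectors stands in for it.
import io

sample_glove = io.StringIO(
    "the 0.418 0.24968 -0.41242\n"
    "atheist -0.1529 0.30177 0.103\n"
)

embeddings_index = {}
for line in sample_glove:          # in the real script: for line in f
    values = line.split()
    word = values[0]               # first token on each line is the word
    coefs = [float(x) for x in values[1:]]  # the rest are vector components
    embeddings_index[word] = coefs

print(len(embeddings_index))       # number of words loaded
```

Each line of the GloVe file is simply the word followed by its vector components, so a plain dictionary of word → list-of-floats is all that is needed.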

 

Next, the line

sequences = tokenizer.texts_to_sequences(texts)

converts text such as

Archive-name: atheism/resources
Alt-atheism-archive-name: resources
Last-modified: 11 December 1992
Version: 1.0

                              Atheist Resources

                      Addresses of Atheist Organizations

                                     USA

FREEDOM FROM RELIGION FOUNDATION

Darwin fish bumper stickers and assorted other atheist paraphernalia are
available from the Freedom From Religion Foundation in the US.

into a structure like the following:

[1237, 273, 1213, 1439, 1071, 1213, 1237, 273, 1439, 192, 2515, 348, 2964, 779, 332, 28, 45, 1628, 1439, 2516, 3, 1628, 2144, 780, 937, 29, 441, 2770, 8854, 4601, 7969, 11979, 5, 12806, 75, 1628, 19, 229, 29, 1, 937, 29, 441, 2770, 6, 1, 118, 558, 2, 90, 106, 482, 3979, 6602, 5375, 1871, 12260, 1632, 17687, 1828, 5101, 1828, 5101, 788, 1, 8854, 4601, 96, 4, 4601, 5455, 64, 1, 751, 563, 1716, 15, 71, 844, 24, 20, 1971, 5, 1, 389, 8854, 744, 1023, 1, 7762, 1300, 2912, 4601, 8, 73, 1698, 6, 1, 118, 558, 2, 1828, 5101, 16500, 13447, 73, 1261, 10982, 170, 66, 6, 1, 869, 2235, 2544, 534, 34, 79, 8854, 4601, 29, 6603, 3388, 264, 1505, 535, 49, 12, 343, 66, 60, 155, 2, 6603, 1043, 1, 427, 8, 73, 1698, 618, 4601, 417, 1628, 632, 11716, 4602, 814, 1628, 691, 3, 1, 467, 2163, 3, 2266, 7491, 5, 48, 15, 40, 135, 378, 8, 1, 467, 6359, 30, 101, 90, 1781, 5, 115, 101, 417, 1628, 632, 17061, 1448, 4317, 45, 860, 73, 1611, 2455, 3343, 467, 7491, 13132, 5814, 1301, 1781, 1, 467, 9477, 667, 11716, 323, 15, 1, 1074, 802, 332, 3, 1, 467, 558, 2, 417, 1628, 632, 90, 106, 482, 2030, 2408, 22, 13799, 853, 2030, 2408, 1871, 3793, 12524, 439, 3793, 13448, 691, 788, 691, 502, 1552, 11221, 116, 993, 558, 2, 2974, 996, 7674, 1184, 1346, 108, 828, 1871, 9478, 12807, 32, 7675, 460, 61, 110, 16, 3362, 22, 1950, 8, 691, 1711, 5622, 233, 1346, 1428, 4623, 1260, 12, 16501, 32, 1044, 7854, 564, 3955, 16501, 5, 1, 500, 3, 564, 27, 4602, 4, 9648, 2913, 10746, 558, 2, 7128, 97, 2456, 2420, 4623, 1260, 12, 16501, 90, 106, 482, 13133, 1346, 1428, 797, 2652, 632, 2366, 445, 3955, 681, 2477, 288, 1184, 

To see why, print:

print("tokenizer.index_word[1237]:", tokenizer.index_word[1237])
print("tokenizer.index_word[273]:", tokenizer.index_word[273])

which outputs:

tokenizer.index_word[1237]: archive
tokenizer.index_word[273]: name

In other words, each word is replaced by its index in the tokenizer's vocabulary.
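Conceptually, the Tokenizer maintains a frequency-ranked word-to-index map and looks each word up in it. A minimal pure-Python mimic (a simplified sketch, not the actual Keras implementation, which also lowercases text and strips punctuation):

```python
# Simplified mimic of Tokenizer.fit_on_texts + texts_to_sequences:
# words are ranked by frequency, the most frequent word gets index 1
# (index 0 is reserved for padding), and each text becomes a list of indices.
from collections import Counter

texts = ["archive name atheism resources",
         "alt atheism archive name resources"]

counts = Counter(w for t in texts for w in t.split())
index_word = {i + 1: w for i, (w, _) in enumerate(counts.most_common())}
word_index = {w: i for i, w in index_word.items()}

sequences = [[word_index[w] for w in t.split()] for t in texts]
print(sequences)
```

Looking an index back up in `index_word` recovers the original word, which is exactly what the `tokenizer.index_word[1237]` prints above demonstrate.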

 

An Embedding layer is then added, initialized from the matrix of pretrained vectors, embedding_matrix.
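Building that matrix can be sketched as follows (toy dimensions and vectors for self-containment; the real script uses EMBEDDING_DIM = 100, the Tokenizer's full word_index, and the GloVe embeddings_index):

```python
# Sketch of assembling embedding_matrix: row i holds the pretrained
# vector of the word whose tokenizer index is i. Toy values throughout.
import numpy as np

EMBEDDING_DIM = 4                                     # 100 in the real example
word_index = {"archive": 1, "name": 2}                # toy word -> index map
embeddings_index = {"archive": [0.1, 0.2, 0.3, 0.4]}  # toy GloVe lookup

num_words = len(word_index) + 1      # +1 because index 0 is reserved
embedding_matrix = np.zeros((num_words, EMBEDDING_DIM))
for word, i in word_index.items():
    vec = embeddings_index.get(word) # words missing from GloVe stay all-zero
    if vec is not None:
        embedding_matrix[i] = vec

# The matrix then initializes the layer, roughly:
#   Embedding(num_words, EMBEDDING_DIM, ..., trainable=False)
# With trainable=False its weights are frozen, which is why the summary
# below reports 2,000,000 non-trainable parameters.
print(embedding_matrix.shape)
```

Note that "name" has no entry in the toy embeddings_index, so its row stays all-zero, matching how the example handles words absent from GloVe.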

 

The network architecture is:

____________________________________________________________________________________________________
Layer (type)                                 Output Shape                            Param #
====================================================================================================
input_1 (InputLayer)                         (None, 1000)                            0
____________________________________________________________________________________________________
embedding_1 (Embedding)                      (None, 1000, 100)                       2000000
____________________________________________________________________________________________________
conv1d_1 (Conv1D)                            (None, 996, 128)                        64128
____________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D)               (None, 199, 128)                        0
____________________________________________________________________________________________________
conv1d_2 (Conv1D)                            (None, 195, 128)                        82048
____________________________________________________________________________________________________
max_pooling1d_2 (MaxPooling1D)               (None, 39, 128)                         0
____________________________________________________________________________________________________
conv1d_3 (Conv1D)                            (None, 35, 128)                         82048
____________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalMaxPooling1D)  (None, 128)                             0
____________________________________________________________________________________________________
dense_1 (Dense)                              (None, 128)                             16512
____________________________________________________________________________________________________
dense_2 (Dense)                              (None, 20)                              2580
====================================================================================================
Total params: 2,247,316
Trainable params: 247,316
Non-trainable params: 2,000,000
____________________________________________________________________________________________________
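The numbers in this summary can be verified by hand. A quick check in plain Python, assuming kernel size 5, 'valid' padding, and pool size 5 as in the example:

```python
# Verify the summary's parameter counts and sequence lengths arithmetically.
def conv1d_params(kernel, in_ch, out_ch):
    return kernel * in_ch * out_ch + out_ch   # weights + one bias per filter

embedding = 20000 * 100                       # vocab cap * EMBEDDING_DIM
conv1 = conv1d_params(5, 100, 128)            # conv1d_1
conv2 = conv1d_params(5, 128, 128)            # conv1d_2
conv3 = conv1d_params(5, 128, 128)            # conv1d_3
dense1 = 128 * 128 + 128                      # dense_1
dense2 = 128 * 20 + 20                        # dense_2: one output per topic

total = embedding + conv1 + conv2 + conv3 + dense1 + dense2
print(total)                                  # matches "Total params"

# Sequence lengths shrink the same way down the stack:
# 1000 -> 996 (conv, valid) -> 199 (pool) -> 195 -> 39 -> 35
assert 1000 - 5 + 1 == 996 and 996 // 5 == 199
```

Only the frozen Embedding layer accounts for the 2,000,000 non-trainable parameters; everything else (247,316 parameters) is trained from scratch.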

 

 
