lstm代码_基于TensorFlow使用LSTM入门:文本分类

简介

任务:基于TensorFlow实现LSTM的入门,简单实现情感分类任务[1]

模型整体结构:文本输入层--embedding层--LSTM层--全连接层--输出层。

说明:以下只是部分代码,如需完整代码,可见文章末尾。

数据集

采用伪造的数据集,文本如下。

# 数据

文本编码

文本编码,将文字转化为数字,以便输入模型。用TensorFlow自带函数VocabularyProcessor实现[2],代码如下。

all_texts = positive_texts + negative_texts
max_document_length = 4
vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
datas = np.array(list(vocab_processor.fit_transform(all_texts)))

构造模型

# 定义输入输出
datas_placeholder = tf.placeholder(tf.int32, [None, max_document_length])
labels_placeholder = tf.placeholder(tf.int32, [None])

# 建立embeddings矩阵
embeddings = tf.get_variable("embeddings", [vocab_size, embedding_size], initializer=tf.truncated_normal_initializer)

# 将词索引号转换为词向量[None, max_document_length] => [None, max_document_length, embedding_size]
embedded = tf.nn.embedding_lookup(embeddings, datas_placeholder)

将数据转化为LSTM输入格式,数组内每一个元素代表一个时间戳,代码如下。

# 转换为LSTM的输入格式,要求是数组,数组的每个元素代表某个时间戳一个Batch的数据
rnn_input = tf.unstack(embedded, max_document_length, axis=1)

基于TensoFlow定义LSTM,函数参数讲解见参考文献[3]

# 定义LSTM,20为输出神经元数量
lstm_cell = BasicLSTMCell(20, forget_bias=1.0)
rnn_outputs, rnn_states = static_rnn(lstm_cell, rnn_input, dtype=tf.float32)

# 利用LSTM最后的输出进行预测,取最后一个时间戳的输出神经元,在其上加全连接层
logits = tf.layers.dense(rnn_outputs[-1], num_classes)

predicted_labels = tf.argmax(logits, axis=1)

# 定义损失和优化器
losses= tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(labels_placeholder, num_classes),
    logits=logits
)

模型训练代码略,如需完整代码可留邮箱。

训练结果

581a10a5c2467f5021ddfeb0c6fc594b.png

以上欢迎交流

GIT地址

进入Git地址后,希望大家点亮右上角的星(如下图)。您的支持将是我的最大动力。

2c542ed5e09cecdd9f9f51522edb6cc0.png
https://github.com/a-bean-sprout/LSTM_Simple_Demo​github.com

参考

  1. ^https://mp.weixin.qq.com/s/HgzCOgpEj9iOfTp_hl9VHA
  2. ^https://www.jianshu.com/p/db400a569730
  3. ^https://blog.csdn.net/u013230189/article/details/82808362
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值