深度学习之RNN②——字符序列(ihello)

import tensorflow as tf
import numpy as np

# Fix TensorFlow's graph-level random seed so that repeated runs of this script
# produce the same weight initialisation and hence the same training log.
tf.set_random_seed(777)  # reproducibility

# Vocabulary: index -> character.
idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach "hello": feeding "hihell" one character per step, the network should
# predict the next character at every step, i.e. the target is "ihello".
x_data = [[0, 1, 0, 2, 3, 3]]   # hihell
y_data = [[1, 0, 2, 3, 3, 4]]   # ihello

num_classes = len(idx2char)        # output classes per time step (= vocabulary size)
input_dim = len(idx2char)          # length of each one-hot input vector
hidden_size = 8                    # number of units in the LSTM cell (= its output size)
batch_size = len(x_data)           # number of sequences fed per batch
sequence_length = len(x_data[0])   # |ihello| == 6
learning_rate = 0.1

# One-hot encode x_data by selecting rows of an identity matrix instead of a
# hand-written table that could drift out of sync with x_data.
# Indexing with an ndarray (not a nested list) guarantees the result has shape
# (batch_size, sequence_length, input_dim) == (1, 6, 5) on every NumPy version.
x_one_hot = np.eye(input_dim, dtype=np.float32)[np.array(x_data)]

# Inputs: one-hot characters, shape (batch, sequence_length, input_dim).
X = tf.placeholder(tf.float32, [None, sequence_length, input_dim])
# Targets: one class index per time step, shape (batch, sequence_length).
# (This statement was previously swallowed into the comment of the line above,
# leaving Y undefined for the loss and the feed_dict.)
Y = tf.placeholder(tf.int32, [None, sequence_length])

# BasicLSTMCell is the most basic LSTM recurrent unit; its constructor arguments
# are similar to those of BasicRNNCell.
cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_size, state_is_tuple=True)
# state_is_tuple=True is the officially recommended setting: the state passed in
# and out is a (c, h) pair -- c is the cell state, h the output -- each of shape
# batch_size x num_units.

# All-zero initial state for a batch of batch_size sequences.
initial_state = cell.zero_state(batch_size, tf.float32)

# dynamic_rnn steps the cell through the time axis of X.
# outputs has shape [batch_size, max_time, cell.output_size], i.e.
# (batch_size, sequence_length, hidden_size); the last dimension is the cell's
# output size, which need not equal the input dimension.
outputs, _states = tf.nn.dynamic_rnn(cell, X, initial_state=initial_state, dtype=tf.float32)

# Fully connected layer shared across time steps: flatten
# (batch_size, sequence_length, hidden_size) -> (batch_size * sequence_length, hidden_size).
X_for_fc = tf.reshape(outputs, [-1, hidden_size])
# tf.layers.dense's `units` argument (num_classes here) sets the number of
# outputs; activation=None leaves them as raw logits.
outputs = tf.layers.dense(X_for_fc, num_classes, activation=None)

# Restore the time axis: -> (batch_size, sequence_length, num_classes).
outputs = tf.reshape(outputs, [batch_size, sequence_length, num_classes])

# Mean over all time steps of the softmax cross-entropy between the logits and
# the integer targets in Y.
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=Y, logits=outputs))

train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

# Most likely class index at every time step, shape (batch_size, sequence_length).
prediction = tf.argmax(outputs, axis=2)

with tf.Session() as sess:
    # Initialise every global variable created while building the graph.
    sess.run(tf.global_variables_initializer())
    for step in range(50):
        # One optimisation step; the loss is fetched in the same run call.
        step_loss, _ = sess.run([loss, train], feed_dict={X: x_one_hot, Y: y_data})
        # Separate run so the prediction reflects the just-updated weights.
        predicted = sess.run(prediction, feed_dict={X: x_one_hot})
        print(step, "loss:", step_loss, "prediction: ", predicted, "true Y: ", y_data)

        # Translate predicted indices back into characters via the vocabulary.
        decoded = "".join(idx2char[idx] for idx in np.squeeze(predicted))
        print(r"\tPrediction str: ", decoded)

0 loss: 1.5785435 prediction:  [[3 3 3 3 3 3]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  llllll
1 loss: 1.421351 prediction:  [[3 3 3 3 3 3]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  llllll
2 loss: 1.278454 prediction:  [[3 3 3 3 3 3]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  llllll
3 loss: 1.1136965 prediction:  [[2 3 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  elello
4 loss: 0.9037611 prediction:  [[2 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ehello
5 loss: 0.7068203 prediction:  [[2 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ehello
6 loss: 0.5666248 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
7 loss: 0.43629977 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
8 loss: 0.32687488 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
9 loss: 0.26123345 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
10 loss: 0.19413996 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
11 loss: 0.14754416 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
12 loss: 0.10863178 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
13 loss: 0.0848211 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
14 loss: 0.06194036 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
15 loss: 0.046265546 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
16 loss: 0.037002042 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
17 loss: 0.028291948 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
18 loss: 0.02109315 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
19 loss: 0.016766297 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
20 loss: 0.013788201 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
21 loss: 0.010870796 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
22 loss: 0.008500288 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
23 loss: 0.006932325 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
24 loss: 0.0059833876 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
25 loss: 0.0054096035 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
26 loss: 0.004837516 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
27 loss: 0.0041285097 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
28 loss: 0.003538878 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
29 loss: 0.003170958 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
30 loss: 0.0029489323 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
31 loss: 0.002775755 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
32 loss: 0.002583687 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
33 loss: 0.0023563588 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
34 loss: 0.0021243908 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
35 loss: 0.0019306702 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
36 loss: 0.0017992812 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
37 loss: 0.0017220915 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
38 loss: 0.0016608015 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
39 loss: 0.0015752836 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
40 loss: 0.0014646194 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
41 loss: 0.0013580621 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
42 loss: 0.0012761353 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
43 loss: 0.0012197478 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
44 loss: 0.0011792878 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
45 loss: 0.0011437086 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
46 loss: 0.0011050947 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
47 loss: 0.0010614416 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
48 loss: 0.0010153268 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
49 loss: 0.0009716902 prediction:  [[1 0 2 3 3 4]] true Y:  [[1, 0, 2, 3, 3, 4]]
\tPrediction str:  ihello
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值