github博客传送门
博客园传送门
本章所需知识:
- 没有基础的请观看深度学习系列视频
- tensorflow
- Python基础
资料下载链接:
深度学习基础网络模型(mnist手写体识别数据集)
MNIST数据集手写体识别(CNN实现)
import tensorflow as tf
# TF 1.x tutorial helper for downloading/reading MNIST.
# NOTE(review): tensorflow.examples.tutorials was removed in TF 2.x — this
# script requires TensorFlow 1.x (tf.placeholder / tf.Session style).
import tensorflow.examples.tutorials.mnist.input_data as input_data
# Loads (and caches) MNIST under ../MNIST_data/; one_hot=True makes each
# label a 10-dim one-hot vector, matching the [None, 10] placeholder below.
mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True)
class SEQ2SEQNet:
    """LSTM encoder-decoder ("seq2seq"-style) classifier for MNIST.

    Each 28x28 image is fed to the encoder as a sequence of 28 rows of
    28 pixels.  The encoder's last-step output (a 128-dim summary) is
    repeated 4 times to form the decoder's input sequence; the decoder's
    last-step output goes through one fully connected layer to produce
    10 class logits.
    """

    def __init__(self, batch_size=100):
        """Create placeholders and FC-layer variables.

        batch_size: the fixed batch size used for the RNN zero-states in
        forward(); default 100 matches the original training loop.
        """
        self.batch_size = batch_size
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='input_x')
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='input_y')
        # FC layer maps the 128-unit decoder output to 10 class logits.
        self.fc1_w = tf.Variable(tf.truncated_normal(shape=[128, 10], dtype=tf.float32, stddev=tf.sqrt(1 / 10)))
        self.fc1_b = tf.Variable(tf.zeros(shape=[10], dtype=tf.float32))

    def forward(self):
        """Build the encoder-decoder graph; defines self.fc_y (logits) and self.output."""
        with tf.variable_scope('encode'):
            self.encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(128)
            self.encoder_states = self.encoder_cell.zero_state(self.batch_size, dtype=tf.float32)
            self.encoder_output, self.encoder_state = tf.nn.dynamic_rnn(
                self.encoder_cell, self.x, initial_state=self.encoder_states, time_major=False)
            # Last time step only: (batch, time, 128) -> (time, batch, 128) -> (batch, 128).
            self.flat = tf.transpose(self.encoder_output, [1, 0, 2])[-1]
            # Repeat the encoder summary 4 times as the decoder's input sequence.
            self.flat1 = tf.expand_dims(self.flat, axis=1)
            self.flat2 = tf.tile(self.flat1, [1, 4, 1])
        with tf.variable_scope('decode'):
            self.decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(128)
            self.decoder_states = self.decoder_cell.zero_state(self.batch_size, dtype=tf.float32)
            self.decoder_output, self.decoder_state = tf.nn.dynamic_rnn(
                self.decoder_cell, self.flat2, initial_state=self.decoder_states, time_major=False)
            # Last decoder time step: (batch, 128).
            self.flat3 = tf.transpose(self.decoder_output, [1, 0, 2])[-1]
        self.fc_y = tf.nn.relu(tf.matmul(self.flat3, self.fc1_w) + self.fc1_b)
        self.output = tf.nn.softmax(self.fc_y)

    def backward(self):
        """Build the loss (softmax cross-entropy) and the Adam train op.

        Cross-entropy is computed from self.fc_y (the pre-softmax logits),
        as softmax_cross_entropy_with_logits requires — not from the
        already-softmaxed self.output.
        """
        self.cost = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.fc_y))
        self.opt = tf.train.AdamOptimizer().minimize(self.cost)

    # Backward-compatible alias: the original code spelled this method
    # 'backword' while the training script called 'backward' (a crash).
    # Keep both names so existing callers of either spelling work.
    backword = backward

    def acc(self):
        """Build the accuracy op: mean fraction of argmax(pred) == argmax(label)."""
        self.acc1 = tf.equal(tf.argmax(self.output, 1), tf.argmax(self.y, 1))
        self.accaracy = tf.reduce_mean(tf.cast(self.acc1, dtype=tf.float32))
if __name__ == '__main__':
    net = SEQ2SEQNet()
    net.forward()
    # Build loss/optimizer.  The method on the class is spelled 'backword';
    # the original script called the nonexistent 'backward' and crashed.
    net.backword()
    net.acc()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(10000):
            # MNIST batches come flattened (100, 784); reshape to the
            # (batch, 28, 28) row-sequence layout the placeholder expects.
            ax, ay = mnist.train.next_batch(100)
            ax_batch = ax.reshape(-1, 28, 28)
            loss, output, accaracy, _ = sess.run(
                fetches=[net.cost, net.output, net.accaracy, net.opt],
                feed_dict={net.x: ax_batch, net.y: ay})
            if i % 100 == 0:
                test_ax, test_ay = mnist.test.next_batch(100)
                test_ax_batch = test_ax.reshape(-1, 28, 28)
                # Reuse the graph's existing accuracy op instead of creating
                # fresh tf.equal/tf.cast/tf.reduce_mean ops every evaluation —
                # the original grew the graph inside the loop (a memory leak
                # in TF1, where the default graph is append-only).
                test_accaracy = sess.run(
                    net.accaracy,
                    feed_dict={net.x: test_ax_batch, net.y: test_ay})
                print(test_accaracy)
最后附上训练截图:
![SEQ2SEQ](https://i-blog.csdnimg.cn/blog_migrate/910c35c1c2c73ef9d16661e38cead8a6.png)