分析:
看 TensorFlow 也有一段时间了,准备按照 GitHub 上的教程,敲出来,顺便整理一下思路。
RNN部分
- 定义参数,包括数据相关,训练相关。
- 定义模型,损失函数,优化函数。
- 训练,准备数据,输入数据,输出结果。
代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn
mnist=input_data.read_data_sets("./data",one_hot=True)
training_rate=0.001
training_iters=100000
batch_size=128
display_step=10
n_input=28
n_steps=28
n_hidden=128
n_classes=10
x=tf.placeholder("float",[None,n_steps,n_input])
y=tf.placeholder("float",[None,n_classes])
weights={'out':tf.Variable(tf.random_normal([n_hidden,n_classes]))}
biases={'out':tf.Variable(tf.random_normal([n_classes]))}
def RN