import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.models.rnn import rnn, rnn_cell
# Read the MNIST train/test splits from the 'data/' directory, with labels
# one-hot encoded (one_hot=True).
mnist = input_data.read_data_sets('data/', one_hot=True)
train_img, train_lbl = mnist.train.images, mnist.train.labels
test_img, test_lbl = mnist.test.images, mnist.test.labels

# Model dimensions. NOTE(review): dim_input == nsteps == 28 presumably means
# each 28-pixel image row is fed as one RNN timestep (28 rows per image) --
# confirm against the model code that consumes these values.
dim_input, dim_hidden, dim_output = 28, 128, 10
nsteps = 28

# Trainable weight matrices, initialised from N(0, 0.1^2):
#   'i2h' has shape [dim_input, dim_hidden]
#   'fc'  has shape [dim_hidden, dim_output]
# Created in the same order as before ('i2h' first) so graph op creation
# order is unchanged.
weights = {}
weights['i2h'] = tf.Variable(
    tf.random_normal([dim_input, dim_hidden], stddev=0.1))
weights['fc'] = tf.Variable(
    tf.random_normal([dim_hidden, dim_output], stddev=0.1))
bias = {
'i2h' :
TensorFlow: recurrent neural network (MNIST basic)
最新推荐文章于 2019-10-17 20:42:40 发布