import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Dataset: read MNIST with one-hot labels and expose the train/test splits
# under short module-level names.
# ---------------------------------------------------------------------------
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train_images, train_labels = mnist.train.images, mnist.train.labels
test_images, test_labels = mnist.test.images, mnist.test.labels

# ---------------------------------------------------------------------------
# Hyperparameters.
# ---------------------------------------------------------------------------
batch_size = 50
learning_rate = 0.001  # not consumed in this section; presumably used by the optimizer later
# Number of whole mini-batches in the training split (floor division drops
# any remainder examples).
n_batch = mnist.train.num_examples // batch_size

# NOTE(review): "Loading data..." is emitted after the dataset has already
# been read above, and "training start...." before the graph is built.
print("Loading data...")
print("training start....")

# ---------------------------------------------------------------------------
# Graph inputs. Creation order of the ops below is kept deliberately:
# auto-generated op names depend on it.
# ---------------------------------------------------------------------------
# Dropout keep probability; a placeholder, so it must be fed on every run.
keep_prob = tf.placeholder(dtype=tf.float32)
# Input batch: any number of rows, 784 features each.
x_data = tf.placeholder(dtype=tf.float32, shape=[None, 784])
# Target batch: any number of rows, 10 columns each (labels are one-hot).
y_data = tf.placeholder(dtype=tf.float32, shape=[None, 10])

# ---------------------------------------------------------------------------
# Hidden layer 1: 784 -> 600, tanh activation, followed by dropout.
# ---------------------------------------------------------------------------
# Weights drawn from a truncated normal (stddev 0.1); biases start at 0.1.
w1 = tf.Variable(tf.truncated_normal(shape=[784, 600], stddev=0.1))
b1 = tf.Variable(tf.zeros([600]) + 0.1)
pre1_linear = tf.matmul(x_data, w1) + b1
pre1 = tf.nn.tanh(pre1_linear)
dro_pre1 = tf.nn.dropout(pre1, keep_prob)
w2 =tf.Variable(tf.truncated_normal([600,400],stdde
tensorflow分类任务MNIST数据集
最新推荐文章于 2024-04-06 20:34:24 发布