微信公众号：数据挖掘与分析学习
1.导入所需的库
from __future__ import print_function

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
2.数据准备
# Load the MNIST dataset from the directory below; one_hot=True makes every
# label a one-hot indicator vector rather than an integer class id.
mnist = input_data.read_data_sets('/data/machine_learning/mnist/', one_hot=True)

# Graph inputs: a batch of flattened images (X) and their one-hot labels (Y).
# The widths are taken from the loaded arrays instead of num_input /
# num_classes, because those constants are only defined further down the
# script; referencing them here would raise NameError.
X = tf.placeholder("float", [None, mnist.train.images.shape[1]])
Y = tf.placeholder("float", [None, mnist.train.labels.shape[1]])
3.参数设置
# Basic training hyperparameters.
learning_rate = 0.1  # step size passed to the Adam optimizer
num_steps = 500      # number of mini-batch training steps
batch_size = 128     # examples drawn per training step
display_step = 100   # log mini-batch loss/accuracy every this many steps
# Network architecture parameters.
n_hidden_1 = 256   # neurons in the first hidden layer
n_hidden_2 = 256   # neurons in the second hidden layer
num_input = 784    # number of input features per example
num_classes = 10   # number of label classes
# Trainable weights and biases for each layer, all drawn from
# tf.random_normal with its default mean/stddev.
# Weight matrices are shaped [fan_in, fan_out] so that x @ W chains layers.
weights = {
    'h1': tf.Variable(tf.random_normal([num_input, n_hidden_1])),
    'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_hidden_2, num_classes])),
}
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([num_classes])),
}
4.模型构建
# Build the model.
def neural_net(x, activation=None):
    """Build the forward pass of the two-hidden-layer fully connected network.

    Uses the module-level ``weights`` and ``biases`` dicts.

    Args:
        x: 2-D input tensor ``[batch, num_input]`` (it is matrix-multiplied
            by ``weights['h1']``, shaped ``[num_input, n_hidden_1]``).
        activation: Optional callable applied to the output of each hidden
            layer, e.g. ``tf.nn.relu``. The default ``None`` reproduces the
            original behaviour of no nonlinearity. Note that without one, the
            stacked layers compose into a single affine map of ``x``, so the
            hidden layers add no representational power.

    Returns:
        Raw (pre-softmax) logits of shape ``[batch, num_classes]``.
    """
    # First hidden fully connected layer.
    layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
    if activation is not None:
        layer_1 = activation(layer_1)
    # Second hidden fully connected layer.
    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
    if activation is not None:
        layer_2 = activation(layer_2)
    # Output layer: deliberately no activation so the result stays as logits.
    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
    return out_layer
5.模型训练和测试
# Forward pass: raw logits feed the loss; softmax probabilities feed accuracy.
logits = neural_net(X)
prediction = tf.nn.softmax(logits)
# Define the loss and optimizer.
# Cross-entropy loss: mean softmax cross-entropy between logits and one-hot labels.
loss_op = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
# Optimizer: Adam with the configured learning rate.
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
# Training op: one Adam update that minimizes the loss.
train_op = optimizer.minimize(loss_op)
# Accuracy: fraction of rows whose highest-probability class equals the
# index of the 1 in the one-hot label.
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Op that assigns initial values to all global variables (run once per session).
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)

    for step in range(1, num_steps + 1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        feed = {X: batch_x, Y: batch_y}
        sess.run(train_op, feed_dict=feed)
        if step % display_step == 0 or step == 1:
            # Calculate loss and accuracy on the mini-batch just trained on.
            loss, acc = sess.run([loss_op, accuracy], feed_dict=feed)
            print("Step {}, Minibatch Loss= {:.4f}, Training Accuracy= {:.3f}"
                  .format(step, loss, acc))

    print("Optimization Finished!")

    # Calculate accuracy for MNIST test images (the full test split in one feed).
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={X: mnist.test.images,
                                        Y: mnist.test.labels}))