Tensorflow入门到实战四(识别手写数字集mnist)

本文介绍了一种使用TensorFlow实现的手写数字识别模型。该模型利用MNIST数据集进行训练,通过添加dropout层来防止过拟合,采用softmax作为激活函数进行多分类预测。在训练过程中,使用交叉熵损失函数和梯度下降优化器进行参数更新,最终模型在测试集上的准确率达到84.26%。
摘要由CSDN通过智能技术生成

手写数字集mnist

任然是集合作为特征数*样本数,特征数代表了某层神经元数量

wx_plus_b = tf.nn.dropout(wx_plus_b,keep_prob)

dropout可以解决过拟合

def add_layer(inputs,in_size,out_size,activation_function=None):
    Weights = tf.Variable(tf.random_normal([out_size,in_size]))
    bias = tf.Variable(tf.zeros([out_size,1]))
    wx_plus_b = tf.matmul(Weights,inputs)+bias
    
    if activation_function is None:
        return wx_plus_b
    else :
        return activation_function(wx_plus_b)
    


xs = tf.placeholder(tf.float32,[784,None])
ys = tf.placeholder(tf.float32,[10,None])


prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)

cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cross_entropy)

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

def compute_accuracy(v_xs, v_ys):
    global prediction
    y_pre = session.run(prediction, feed_dict={xs: v_xs})
    # tf.argmax (y_pre,1 ) 返回每一行 下标最大的元素,1表示按行
    correct_prediction = tf.equal(tf.argmax(y_pre,0), tf.argmax(v_ys,0))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    result = session.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
    return result
for i in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    session.run(train_step,feed_dict={xs:np.transpose(batch_xs),ys:np.transpose(batch_ys)})
    if i%50==0:
        print("loss",session.run(cross_entropy,feed_dict={xs:np.transpose(batch_xs),ys:np.transpose(batch_ys)}))
        print(compute_accuracy(np.transpose(mnist.test.images), np.transpose(mnist.test.labels)))

输出

loss 148.512
0.1313
loss 48.9171
0.6285
loss 45.432
0.6736
loss 36.675
0.6946
loss 35.2471
0.7188
loss 33.0088
0.7398
loss 33.254
0.7498
loss 33.4405
0.7662
loss 30.5252
0.7718
loss 35.2746
0.7744
loss 33.5173
0.7945
loss 33.7244
0.7903
loss 30.1716
0.8109
loss 30.2718
0.8003
loss 29.682
0.8288
loss 31.7652
0.8243
loss 29.8357
0.8333
loss 30.9754
0.8324
loss 30.0272
0.835
loss 29.3583
0.8426

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值