tensorflow实战二

tensorflow实战二

概述

本篇博文利用mnist数据集,利用最简单的单层神经网络识别手写数字。

代码实现

本博文使用了最简单的单层神经网络结构,输入层有784个输入神经元,输出层有10个神经元,采用one_hot形式读入数据。
话不多说,直接上代码

#-*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#载入数据集
minist = input_data.read_data_sets("minist",one_hot=True)

#定义批次大小,数据量太大,采用随机梯度下降法进行批次训练,一个批次暂定50张训练数据
batch_size = 50
batch_n = minist.train.num_examples//batch_size

#定义两个占位符
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#创建一个简单神经网络
w1 = tf.Variable(tf.zeros([784,10]))
b1 = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,w1) + b1) #使用softmax作为激活函数

loss = tf.reduce_mean(tf.square(y-prediction)) #定义代价函数
train = tf.train.GradientDescentOptimizer(0.05).minimize(loss) #定义优化器

#计算准确率
correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for step in range(50):
    for _ in range(batch_n):
        batch_xs,batch_ys = minist.train.next_batch(batch_size) #每次取50张图片
        sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
    acc = sess.run(accuracy,feed_dict={x:minist.test.images,y:minist.test.labels})
print "Iter"+str(step)+":"+str(acc)

sess.close()

结果和小结

上述神经网络的结构是最简单的网络,最终得到了正确率为0.9158 。大家可以从参数初始化、批次调节、迭代次数、激活函数、增加隐藏层这几种方法进行改进。
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值