Tensorflow逻辑回归处理MNIST数据集

#1:导入所需的软件
import tensorflow as tf
'''
获取mnist数据放在当前文件夹下,利用input_data函数解析该数据集
train_img和train——label构成训练集,包含60000个手写体数字图片和对应的标签
test_img和test_label表示测试集,包含10000个样本和10000个标签
'''
from tensorflow.examples.tutorials.mnist import  input_data
mnist = input_data.read_data_sets('MNIST_data/',one_hot=True)

train_img = mnist.train.images
train_label = mnist.train.labels
test_img = mnist.test.images
test_label = mnist.test.labels


#3:在Tensorflow图中为训练数据集的输入x和标签y创建占位符

x = tf.compat.v1.placeholder(tf.float32,[None,784],name='X')
y = tf.compat.v1.placeholder(tf.float32,[None,10],name='Y')

#4:创建学习变量,权重和偏置

W = tf.compat.v1.Variable(tf.zeros([784,10]),name='W')
b = tf.compat.v1.Variable(tf.zeros([10]),name='b')

#5:创建逻辑回归模型。

y_hat = tf.nn.softmax(tf.matmul(x,W)+b)

#6: 损失(loss)函数

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=y_hat))

#7:采用Tensorflow GradientDescentOptimizer,学习率为0.01

optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)

#预测结果
pred = tf.equal(tf.argmax(y_hat,1),tf.argmax(y,1))

#计算准确率
accuracy = tf.reduce_mean(tf.cast(pred,'float'))

#8:为变量进行初始化:
init = tf.global_variables_initializer()
training_epochs = 50
batch_size = 100
display_step = 5

with tf.compat.v1.Session() as sess:
    sess.run(init)
    summary_writer = tf.summary.FileWriter('graphs3',sess.graph)
    for epoch in range(training_epochs):
        loss_avg = 0
        num_of_batch = int(mnist.train.num_examples/batch_size)
        for i in range(num_of_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            feeds_train = {x:batch_xs,y:batch_ys}
            sess.run([optimizer,loss],feed_dict=feeds_train)
            loss_avg+=sess.run(loss,feed_dict=feeds_train)/num_of_batch
            #训练过程输出
            if epoch % display_step ==0:
                feeds_test = {x:mnist.test.images,y:mnist.test.labels}
                train_acc = sess.run(accuracy,feed_dict=feeds_train)
                test_acc = sess.run(accuracy,feed_dict=feeds_test)
                print()
            print('Epoch {0}:Loss {1}:train_acc:{2}:test_acc:{3}'.format(epoch,loss_avg,train_acc,test_acc))
    print("Done")


结果输入
在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值