tensorflow(3)Logistic regression

本文介绍了一个基于TensorFlow的手写数字识别模型。该模型使用了MNIST数据集进行训练,并通过Softmax回归实现了对手写数字的有效识别。文章详细记录了从数据加载到模型训练的全过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

MNIST手写数字识别
tensorflow基本概念

# coding=utf-8
import tensorflow as tf
import input_data
mnist = input_data.read_data_sets('data/',one_hot=True)
train_img   = mnist.train.images
train_label = mnist.train.labels
test_img    = mnist.train.images
test_label  = mnist.train.labels

print ("mnist loaded")

'''
saver = tf.train.Saver()
with tf.Session() as sess:
    save_path = saver.save(sess,"/home/aitab/PycharmProjects/mnist2")
    print("model saved in file:",save_path)
'''
print (train_img.shape)
print (train_label.shape)
print (test_img.shape)
print (test_label.shape)
#print (trainimg)
print (train_label[0])

x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])

W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

# model
actv = tf.nn.softmax(tf.matmul(x,W) + b)
# loss function
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
# optimezer
learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# prediction
pred = tf.equal(tf.arg_max(actv,1),tf.arg_max(y,1))
# accuracy
accr = tf.reduce_mean(tf.cast(pred,"float"))
# initalizer
init = tf.global_variables_initializer()

train_epochs = 50
batch_size   = 100
step         = 5
sess = tf.Session()
sess.run(init)

for epoch in range(train_epochs):
    avg_cost = 0.
    batch_num = int(mnist.train.num_examples/batch_size)
    for i in range(batch_num):
        # 每次丢入一个batch
        batch_xs,batch_ys = mnist.train.next_batch(batch_size)
        sess.run(optm,feed_dict={x:batch_xs,y:batch_ys})
        feeds = {x:batch_xs,y:batch_ys}
        avg_cost += sess.run(cost,feed_dict=feeds)/batch_num

    # display
    if epoch % step == 0:
        feeds_train = {x:batch_xs,y:batch_ys}
        feeds_test  = {x:mnist.test.images,y:mnist.test.labels}
        train_acc   = sess.run(accr,feed_dict=feeds_train)
        test_acc    = sess.run(accr,feed_dict=feeds_test)
        print("Epoch: %03d/%03d cost:%.9f train_acc:%.3f test_acc:%.3f"% (epoch,train_epochs,test_acc,test_acc))
   
print ("DONE")

OUT:
Epoch: #03d/000 cost:50.000000000 train_acc:0.852 test_acc:0.852
Epoch: #03d/005 cost:50.000000000 train_acc:0.895 test_acc:0.895
Epoch: #03d/010 cost:50.000000000 train_acc:0.905 test_acc:0.905
Epoch: #03d/015 cost:50.000000000 train_acc:0.909 test_acc:0.909
Epoch: #03d/020 cost:50.000000000 train_acc:0.913 test_acc:0.913
Epoch: #03d/025 cost:50.000000000 train_acc:0.914 test_acc:0.914
Epoch: #03d/030 cost:50.000000000 train_acc:0.915 test_acc:0.915
Epoch: #03d/035 cost:50.000000000 train_acc:0.917 test_acc:0.917
Epoch: #03d/040 cost:50.000000000 train_acc:0.918 test_acc:0.918
Epoch: #03d/045 cost:50.000000000 train_acc:0.918 test_acc:0.918
DONE

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值