使用tensorflow训练和测试手写字体识别

import os
import cv2 as cv
import tensorflow as tf

STEPS=100000             #迭代次数
BATCH_SIZE=64           #训练批次
TRAIN_NUM=5000          #训练样本数量
TEST_NUM=1000           #测试样本数量
DISPLAY_ITER=500        #迭代多少次打印
TEST_ITER=5000          #迭代多少次测试
SNAPSHOT=20000          #迭代多少次保存

def train(train_path, val_path):

    #训练集
    list = os.listdir(train_path)
    train_data=[]
    train_label=[]
    
    for filename in list:
        filepath = '%s\\%s' % (train_path, filename)
        img = cv.imread(filepath, 0)
        if img is None:
            continue
        img = img / 255
        rows,cols = img.shape
        img = img.reshape((rows*cols))
        train_data.append(img)#一维数据
        labels = [0] * 10
        labels[int(filename.split('_')[0])] = 1
        train_label.append(labels)#数据标签

    print('train data load!\n')

    #测试集
    list = os.listdir(val_path)
    val_data=[]
    val_label=[]

    for filename in list:
        filepath = '%s\\%s' % (val_path, filename)
        img = cv.imread(filepath, 0)
        if img is None:
            continue
        img = img / 255
        rows,cols = img.shape
        img = img.reshape((rows*cols))
        val_data.append(img)#一维数据
        labels = [0] * 10
        labels[int(filename.split('_')[0])] = 1
        val_label.append(labels)#数据标签

    print('test data load!\n')

    #定义网络
    net_data_input = tf.placeholder(tf.float32, shape=(None, 28*28))
    net_label_input = tf.placeholder(tf.float32, shape=(None, 10))

    w1 = tf.Variable(tf.random_normal([28*28,500]))
    b1 = tf.Variable(tf.random_normal([500]))
    w2 = tf.Variable(tf.random_normal([500,10]))
    b2 = tf.Variable(tf.random_normal([10]))

    fc1 = tf.matmul(net_data_input, w1) + b1
    relu1 = tf.nn.relu(fc1)
    fc2 = tf.matmul(fc1, w2) + b2

    #学习率
    learning_rate = tf.train.exponential_decay(
        0.1,
        tf.Variable(0, trainable=False),
        TRAIN_NUM/BATCH_SIZE,
        0.99,
        staircase=True)

    #损失函数
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=fc2, labels=tf.argmax(net_label_input,1))
    loss = tf.reduce_mean(ce)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    #测试
    correct_prediction = tf.equal(tf.argmax(fc2, 1), tf.argmax(net_label_input, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    #保存模型
    saver = tf.train.Saver()

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        for i in range(STEPS+1):
            start = (i*BATCH_SIZE) % TRAIN_NUM
            end = start + BATCH_SIZE

            sess.run(train_step, feed_dict={net_data_input:train_data[start:end], 
                                            net_label_input:train_label[start:end]})

            if i % DISPLAY_ITER == 0:
                loss_value = sess.run(loss, feed_dict={net_data_input:train_data, net_label_input:train_label})
                print("[iter:%d] [lr:] [train loss:%g]" % (i, loss_value))

            if i % TEST_ITER == 0:
                accuracy_score = sess.run(accuracy, feed_dict={net_data_input:val_data, net_label_input:val_label})
                print("\t[iter:%d] [lr:] [test accuracy:%g]" % (i, accuracy_score))

            if i % SNAPSHOT == 0:
                saver.save(sess, './model_%d' % i)

def test():

    #定义网络
    net_data_input = tf.placeholder(tf.float32, shape=(None, 28*28))

    w1 = tf.Variable(tf.random_normal([28*28,500]))
    b1 = tf.Variable(tf.random_normal([500]))
    w2 = tf.Variable(tf.random_normal([500,10]))
    b2 = tf.Variable(tf.random_normal([10]))

    fc1 = tf.matmul(net_data_input, w1) + b1
    relu1 = tf.nn.relu(fc1)
    fc2 = tf.matmul(fc1, w2) + b2

    #测试图片
    img = cv.imread('D:\\8.jpg', 0)
    ret, img = cv.threshold(img,128,255,cv.THRESH_OTSU)
    img = img / 255
    rows,cols = img.shape
    img = img.reshape((rows*cols))

    #
    preValue = tf.arg_max(fc2, 1)
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state('./')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            preValue = sess.run(preValue, feed_dict={net_data_input:[img]})

    print('preValue:%d' % preValue)

if __name__ == '__main__':
    #train('E:\\[1]Paper\\Datasets\\MINST\\train', 'E:\\[1]Paper\\Datasets\\MINST\\query')
    test()

tensorflow的入门可以参考中国大学MOOC上曹建老师的tensorflow笔记。

 

附训练图片格式:

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值