import os
import cv2 as cv
import tensorflow as tf
# Total number of training iterations.
STEPS = 100000
# Number of samples fed per training step.
BATCH_SIZE = 64
# Number of training samples.
TRAIN_NUM = 5000
# Number of test samples.
TEST_NUM = 1000
# Print the training loss every this many iterations.
DISPLAY_ITER = 500
# Evaluate test accuracy every this many iterations.
TEST_ITER = 5000
# Save a checkpoint every this many iterations.
SNAPSHOT = 20000
# Network geometry shared by train() and test(). Both functions must build the
# exact same graph so that a checkpoint written by train() can be restored by
# test().
_IMAGE_SIDE = 28
_INPUT_DIM = _IMAGE_SIDE * _IMAGE_SIDE
_HIDDEN_UNITS = 500
_NUM_CLASSES = 10


def _load_dataset(dir_path):
    """Load every readable image in *dir_path* as a flat vector plus one-hot label.

    Files that OpenCV cannot decode are skipped. The class index is parsed
    from the part of the file name before the first '_'.

    Returns:
        (data, labels): two parallel lists; each data item is a 1-D array of
        length 28*28 scaled by 1/255, each label a 10-element one-hot list.

    Raises:
        ValueError: if an image is not 28x28, or the file-name prefix is not
            an integer.
    """
    data = []
    labels = []
    for filename in os.listdir(dir_path):
        # os.path.join instead of a hard-coded '\\' so loading also works
        # outside Windows.
        img = cv.imread(os.path.join(dir_path, filename), 0)
        if img is None:
            continue
        if img.shape != (_IMAGE_SIDE, _IMAGE_SIDE):
            # Fail here with a clear message rather than later inside
            # sess.run with a shape-mismatch error.
            raise ValueError('%s: expected a %dx%d image, got %dx%d'
                             % (filename, _IMAGE_SIDE, _IMAGE_SIDE,
                                img.shape[0], img.shape[1]))
        data.append((img / 255).reshape(_INPUT_DIM))
        one_hot = [0] * _NUM_CLASSES
        one_hot[int(filename.split('_')[0])] = 1
        labels.append(one_hot)
    return data, labels


def _build_network():
    """Create the 784-500-10 fully connected network in the default graph.

    The four variables are created in the order w1, b1, w2, b2 and left
    unnamed so their auto-generated names match previously saved checkpoints.

    Returns:
        (net_data_input, logits): the input placeholder and the output layer.
    """
    net_data_input = tf.placeholder(tf.float32, shape=(None, _INPUT_DIM))
    w1 = tf.Variable(tf.random_normal([_INPUT_DIM, _HIDDEN_UNITS]))
    b1 = tf.Variable(tf.random_normal([_HIDDEN_UNITS]))
    w2 = tf.Variable(tf.random_normal([_HIDDEN_UNITS, _NUM_CLASSES]))
    b2 = tf.Variable(tf.random_normal([_NUM_CLASSES]))
    fc1 = tf.matmul(net_data_input, w1) + b1
    relu1 = tf.nn.relu(fc1)
    # The second layer consumes the ReLU output. It used to read fc1 directly,
    # which left the activation unused and the whole model linear.
    fc2 = tf.matmul(relu1, w2) + b2
    return net_data_input, fc2


def train(train_path, val_path):
    """Train the digit classifier and periodically report, evaluate and save.

    Args:
        train_path: directory of training images named '<label>_*'.
        val_path: directory of evaluation images with the same naming.

    Uses the module constants STEPS, BATCH_SIZE, TRAIN_NUM, DISPLAY_ITER,
    TEST_ITER and SNAPSHOT. Checkpoints are written to './model_<iter>'.

    Raises:
        ValueError: if either directory yields no readable images.
    """
    # Training set
    train_data, train_label = _load_dataset(train_path)
    if not train_data:
        raise ValueError('no readable training images in %s' % train_path)
    print('train data load!\n')
    # Evaluation set
    val_data, val_label = _load_dataset(val_path)
    if not val_data:
        raise ValueError('no readable test images in %s' % val_path)
    print('test data load!\n')

    # Never index past what was actually loaded; otherwise a short dataset
    # would produce empty batches.
    num_train = min(TRAIN_NUM, len(train_data))

    # Network definition
    net_data_input, logits = _build_network()
    net_label_input = tf.placeholder(tf.float32, shape=(None, _NUM_CLASSES))

    # Learning rate: decays by 0.99 once per pass over the training samples.
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        0.1,
        global_step,
        float(num_train) / BATCH_SIZE,
        0.99,
        staircase=True)

    # Loss
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.argmax(net_label_input, 1))
    loss = tf.reduce_mean(ce)
    # global_step must be handed to minimize() so it is incremented on every
    # step; without that the decay schedule above never advances.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Evaluation
    correct_prediction = tf.equal(tf.argmax(logits, 1),
                                  tf.argmax(net_label_input, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Model saving
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(STEPS + 1):
            start = (i * BATCH_SIZE) % num_train
            end = min(start + BATCH_SIZE, num_train)
            sess.run(train_step,
                     feed_dict={net_data_input: train_data[start:end],
                                net_label_input: train_label[start:end]})
            if i % DISPLAY_ITER == 0:
                loss_value, lr_value = sess.run(
                    [loss, learning_rate],
                    feed_dict={net_data_input: train_data,
                               net_label_input: train_label})
                print("[iter:%d] [lr:%g] [train loss:%g]"
                      % (i, lr_value, loss_value))
            if i % TEST_ITER == 0:
                accuracy_score, lr_value = sess.run(
                    [accuracy, learning_rate],
                    feed_dict={net_data_input: val_data,
                               net_label_input: val_label})
                print("\t[iter:%d] [lr:%g] [test accuracy:%g]"
                      % (i, lr_value, accuracy_score))
            if i % SNAPSHOT == 0:
                saver.save(sess, './model_%d' % i)


def test(image_path='D:\\8.jpg'):
    """Classify one image with the latest checkpoint found in './'.

    Args:
        image_path: image to classify; it must decode to 28x28 pixels.
            Defaults to the path that was previously hard-coded.

    Prints the predicted class index as 'preValue:<n>'.

    Raises:
        IOError: if the image cannot be read or no checkpoint is found.
        ValueError: if the image is not 28x28.
    """
    # Network definition (must match the one used by train()).
    net_data_input, logits = _build_network()

    # Test image
    img = cv.imread(image_path, 0)
    if img is None:
        raise IOError('cannot read image: %s' % image_path)
    if img.shape != (_IMAGE_SIDE, _IMAGE_SIDE):
        raise ValueError('expected a %dx%d image, got %dx%d'
                         % (_IMAGE_SIDE, _IMAGE_SIDE,
                            img.shape[0], img.shape[1]))
    # With THRESH_OTSU OpenCV chooses the threshold itself, so the 128 passed
    # here is only a placeholder argument.
    _, img = cv.threshold(img, 128, 255, cv.THRESH_OTSU)
    img = (img / 255).reshape(_INPUT_DIM)

    prediction = tf.argmax(logits, 1)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state('./')
        if not (ckpt and ckpt.model_checkpoint_path):
            # Without a restore the variables are uninitialised and the
            # session run below would fail anyway; report the real cause.
            raise IOError('no checkpoint found in ./')
        saver.restore(sess, ckpt.model_checkpoint_path)
        pre_value = sess.run(prediction, feed_dict={net_data_input: [img]})
        print('preValue:%d' % pre_value[0])
if __name__ == "__main__":
    # Script entry point: classify the default image with the checkpoint
    # found in './'. To train a model first, enable the call below.
    # train('E:\\[1]Paper\\Datasets\\MINST\\train', 'E:\\[1]Paper\\Datasets\\MINST\\query')
    test()
TensorFlow的入门可以参考中国大学MOOC上曹健老师的《人工智能实践:Tensorflow笔记》课程。
附训练图片格式: