lenet-5实现mnist手写数字识别
关于LeNet-5模型以及MNIST数据集,我就不多赘述了,网上已有很多相关的介绍,所以这里直接上代码。
新手入门,代码还很不完善,以下代码仅供参考。
1、mnist_lenet5_forward.py (前向传播)
#coding:utf-8
import tensorflow as tf
# Hyper-parameters of the network assembled by forward().
# Presumably the side length of the square input images; not referenced
# in the visible code -- TODO confirm against the training script.
IMAGE_SIZE = 28
# Input channel count of the first convolution kernel.
NUM_CHANNELS = 1
# Side length of the (square) first convolution kernel.
CONV1_SIZE = 5
# Output channel count of the first convolution (and size of its bias).
CONV1_KERNEL_NUM = 32
# Side length of the (square) second convolution kernel.
CONV2_SIZE = 5
# Output channel count of the second convolution (and size of its bias).
CONV2_KERNEL_NUM = 64
# Width of the first fully connected layer.
FC_SIZE = 512
# Width of the final fully connected layer, i.e. the size of forward()'s output.
OUTPUT_NODE = 10
def get_weight(shape, regularizer):
    """Create a weight variable and optionally register its L2 penalty.

    Args:
        shape: Shape of the variable to create.
        regularizer: Scale passed to ``tf.contrib.layers.l2_regularizer``,
            or ``None`` to skip regularization entirely.

    Returns:
        A ``tf.Variable`` initialised from a truncated normal distribution
        with standard deviation 0.1.
    """
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer is not None:
        penalty = tf.contrib.layers.l2_regularizer(regularizer)(w)
        # A zero scale makes l2_regularizer produce None instead of a tensor;
        # keep that out of the 'losses' collection so it only holds tensors.
        if penalty is not None:
            tf.add_to_collection('losses', penalty)
    return w
def get_bias(shape):
    """Return a new ``tf.Variable`` of the given shape, initialised to zeros."""
    return tf.Variable(tf.zeros(shape))
def conv2d(x, w):
    """Convolve ``x`` with kernel ``w`` using unit strides and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, w, strides=unit_strides, padding='SAME')
def max_pool_2x2(x):
    """Max-pool ``x`` with window [1, 2, 2, 1], stride [1, 2, 2, 1], 'SAME' padding."""
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME')
def forward(x, train, regularizer):
    """Build the forward pass: two conv/pool stages followed by two dense layers.

    Args:
        x: Input tensor fed to the first convolution. The first kernel is
            built for NUM_CHANNELS input channels; assumes the default NHWC
            layout of ``tf.nn.conv2d`` -- TODO confirm against the caller.
            The batch dimension may be fixed or unknown (``None``).
        train: When truthy, dropout (second argument 0.5) is applied to the
            first fully connected layer.
        regularizer: Forwarded to ``get_weight`` for every weight variable;
            ``None`` disables regularization.

    Returns:
        The output of the last fully connected layer (width OUTPUT_NODE),
        with no activation applied.
    """
    # Stage 1: convolution -> bias -> ReLU -> 2x2 max pooling.
    conv1_w = get_weight(
        [CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_KERNEL_NUM], regularizer)
    conv1_b = get_bias([CONV1_KERNEL_NUM])
    conv1 = conv2d(x, conv1_w)
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_b))
    pool1 = max_pool_2x2(relu1)

    # Stage 2: same pattern, taking stage 1's channels as input.
    conv2_w = get_weight(
        [CONV2_SIZE, CONV2_SIZE, CONV1_KERNEL_NUM, CONV2_KERNEL_NUM], regularizer)
    conv2_b = get_bias([CONV2_KERNEL_NUM])
    conv2 = conv2d(pool1, conv2_w)
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_b))
    pool2 = max_pool_2x2(relu2)

    # Flatten everything except the batch axis.
    pool_shape = pool2.get_shape().as_list()
    nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
    # Use -1 rather than pool_shape[0]: the static batch size is None when
    # the input has a dynamic batch dimension, which tf.reshape rejects.
    # For a fixed batch size, -1 resolves to the same value.
    reshaped = tf.reshape(pool2, [-1, nodes])

    # First fully connected layer, with dropout only while training.
    fc1_w = get_weight([nodes, FC_SIZE], regularizer)
    fc1_b = get_bias([FC_SIZE])
    fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_w) + fc1_b)
    if train:
        fc1 = tf.nn.dropout(fc1, 0.5)

    # Output layer: linear, no activation.
    fc2_w = get_weight([FC_SIZE, OUTPUT_NODE], regularizer)
    fc2_b = get_bias([OUTPUT_NODE])
    y = tf.matmul(fc1, fc2_w) + fc2_b
    return y
2、mnist_lenet5_generateds.py (生成tfRecord文件)
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import os
import glob
# Dataset locations and image size for TFRecord generation.
# NOTE(review): the code that consumes these names is outside the visible
# part of the file; the roles below are inferred from the names -- confirm.
# Presumably the directory holding the training images.
image_train_path = './fashion_mnist_png/fashion_mnist_train/'
# label_train_path = './fashion_mnist_png/mnist_train_jpg_60000.txt'
# Presumably the output TFRecord file for the training set.
tfRecord_train = './fashion_data/fashion_mnist_train.tfrecords'
# Presumably the directory holding the test images.
image_test_path = './fashion_mnist_png/fashion_mnist_test/'
# label_test_path = './fashion_mnist_png/mnist_test_jpg_10000.txt'
# Presumably the output TFRecord file for the test set.
tfRecord_test = './fashion_data/fashion_mnist_test.tfrecords'
# Presumably the directory in which the TFRecord files are stored.
data_path = './fashion_data'
# Presumably the target height and width images are resized to.
resize_height = 28
resize_width = 28
def write_tfRecord(tfRecordName, image_path): #