介绍
使用 TensorFlow 实现 LeNet-5 网络，用于 MNIST 手写数字数据集的识别。
代码
'''
Created on 2018年4月27日
@author: wangs0622
'''
import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import flatten
from tensorflow.examples.tutorials.mnist import input_data
def _check_split(name, images, labels):
    """Raise ValueError unless *images* and *labels* have the same length.

    Used instead of ``assert``, which is removed when Python runs with -O,
    so the sanity check would otherwise silently disappear.
    """
    if len(images) != len(labels):
        raise ValueError(
            "{} split: {} images but {} labels".format(
                name, len(images), len(labels)))


# Directory containing the MNIST files. The original hard-coded Windows path
# is kept as the default; set the MNIST_DATA_DIR environment variable to point
# the script at a different location without editing the source.
filepath = os.environ.get("MNIST_DATA_DIR", r"G:\python\datasets\mnist")

# reshape=False asks the TensorFlow tutorial helper not to flatten each image
# into a vector; the 4-entry pad widths used below require 4-D image arrays.
mnist = input_data.read_data_sets(filepath, reshape=False)
x_train, y_train = mnist.train.images, mnist.train.labels
x_validation, y_validation = mnist.validation.images, mnist.validation.labels
x_test, y_test = mnist.test.images, mnist.test.labels

# Every split must have exactly one label per image.
_check_split("train", x_train, y_train)
_check_split("validation", x_validation, y_validation)
_check_split("test", x_test, y_test)

print("Image shape = {}".format(x_train[0].shape))
print("Training set size = {}".format(len(x_train)))
print("Validation set size = {}".format(len(x_validation)))
print("Test set size = {}".format(len(x_test)))

# Zero-pad only the two middle (spatial) axes by 2 pixels on each side; the
# first (sample) and last axes are left unchanged. Presumably this turns the
# 28x28 MNIST digits into the 32x32 input LeNet-5 expects -- confirm against
# the "Image shape" printed above.
x_train = np.pad(x_train, [(0, 0), (2, 2), (2, 2), (0, 0)], "constant")
x_validation = np.pad(
    x_validation, [(0, 0), (2, 2), (2, 2), (0, 0)], "constant")
x_test = np.pad(x_test, [(0,0),(2</