这是TensorFlow的测试用例之一,本文将逐行予以解析。
本例使用手写数字识别MNIST,MNIST是机器学习领域的一个经典问题。指的是让机器查看一系列大小为28x28像素的手写数字灰度图像,并判断这些图像代表0-9中的哪一个数字。
- 准备数据
# 获取参数
FLAGS, b = parser.parse_known_args()
# load data
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
其中,各种参数的配置如下:
parser = argparse.ArgumentParser() # argparse是python3.*中处理参数的标准库
# 导入具体的参数
parser.add_argument(
'--learning_rate',
type=float,
default=0.01,
help='Initial learning rate.'
)
parser.add_argument(
'--max_steps',
type=int,
default=2000,
help='Number of steps to run trainer.'
)
parser.add_argument(
'--hidden1',
type=int,
default=128,
help='Number of units in hidden layer 1.'
)
parser.add_argument(
'--hidden2',
type=int,
default=32,
help='Number of units in hidden layer 2.'
)
parser.add_argument(
'--batch_size',
type=int,
default=100,
help='Batch size. Must divide evenly into the dataset sizes.'
)
parser.add_argument(
'--input_data_dir',
type=str,
default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
'tensorflow/mnist/input_data'),
help='Directory to put the input data.'
)
parser.add_argument(
'--log_dir',
type=str,
default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
'tensorflow/mnist/logs/fully_connected_feed'),
help='Directory to put the log data.'
)
parser.add_argument(
'--fake_data',
default=False,
help='If true, uses fake data for unit testing.',
action='store_true'
)
- 构建图
TensorFlow是通过构建图来运行操作,对应于多层前馈神经网络,需要构筑2个隐藏层,mnist.inference就是这样的函数,其hidden_layer1是 tf.nn.relu(tf.matmul(images, weights) + biases),用激活函数relu来运算 wx+