一、介绍
基于tensorflow框架实现的Mnist数据分类。代码主要包括网络结构的搭建,训练超参数的导入和保存,损失函数地绘制等。不足之处是在网络结尾没用使用softmax函数,而直接使用了tanh输出了分类结果。下面请看代码的详细介绍
二、代码
- 导入必要的包文件,需要的包我直接通过pycharm导入的,能导入的原因是采用了anaconda3底下的python.exe,新建工程的时候,从外部导入
# 需要使用到的包文件
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import argparse
import os
# 加上这一句能够使Plot绘制出来的图更精美
sns.set_style("whitegrid")
- 训练参数设置,详细介绍请看代码注释,主要采用了argparse,该模块的好处是直接可以在运行时修改参数,比如:python main.py --data_dir= "**"
parser = argparse.ArgumentParser(description="Network for image classification")
parser.add_argument('--data_dir', default='tem/data', help='Directory for training data') # Mnist数据集存放位置
parser.add_argument('--result_dir', default='tem/result') # 训练结果的存放
parser.add_argument('--model_dir', default='model/', help='the place of saving networks parameters') #训练参数的存放地址
parser.add_argument('--batch_size', default=32)
parser.add_argument('--print_loss', default=10) # 每隔10次迭代打印损失值
parser.add_argument('--plot_loss', default=100) # 每隔100次迭代绘制损失函数曲线
parser.add_argument('--learning_rate', default=0.001, type=float) # 学习率,不易设置过大
parser.add_argument('--n_iterations', default=10000, type=int) # 迭代次数
args = parser.parse_args() # 将--*的*传递给arg,调用时直接使用args.data_dir这样的结构
- 网络结构搭建
w_init = tf.random_normal_initializer(stddev=0.01) # 权重w初始化,标准差为0.01,平均值0
def network(x): # 激活函数都为relu,除了输出
layers1 = tf.layers.conv2d(x, 32, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init) # 32个卷积核,3x3卷积核大小,步长为1,padding为'same',即输出大小为input/stride,向上取整
layers2 = tf.layers.conv2d(layers1, 62, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init)
layers2_flatten = tf.contrib.layers.flatten(layers2) # 将layers2的输出"磨平",降低相关维度,以供全连接层工作
layers3 = tf.layers.dense(layers2_flatten, 200, activation=tf.nn.relu, kernel_initializer=w_init) # 200为全连接层单元个数,其它的痛卷积函数类似
output = tf.layers.dense(layers3, 10, activation=tf.nn.tanh, kernel_initializer=w_init) # 使用tanh作为输出,比sigmoid好,因为sigmoid是有0项,不利于网络训练
return output
- 训练网络,详细介绍看注释
def training():
input_x = tf.placeholder(tf.float32, [None, 28, 28, 1]) # 放置占位矩阵
label_y = tf.placeholder(tf.float32, [None, 10])
output_y = network(input_x) # 前向传播
loss = tf.reduce_sum(tf.square(label_y-output_y)) # 计算同便签损失
optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate).minimize(loss) # 使用Adam优化
init_all_v = tf.global_variables_initializer() # 张量初始化函数
sess = tf.InteractiveSession()
sess.run(init_all_v) # 实行张量初始化
saver = load_model(sess) # 导入之前训练过的参数,如果没有则打印出来
mnist = read_data_sets(args.data_dir, one_hot=True) # 往指定目录读取Mnist数据集
print('start training')
plot_loss = [] # 损失值缓存
for i in range(args.n_iterations):
batch_x, batch_y = mnist.train.next_batch(args.batch_size) # 读取Batch_size
batch_x = batch_x.reshape([args.batch_size, 28, 28, 1]) # 维度匹配
y = np.zeros([args.batch_size, 10]) # 下面的操作是因为我读到的标签是6,8,9直接对应的图片的数字,所以将这些数字向量化,以便训练
for j in range(args.batch_size):
k = batch_y[j].astype(np.int)
y[j, k] = 1.
batch_y = y
d_loss, _ = sess.run([loss, optimizer], feed_dict={input_x:batch_x, label_y:batch_y}) # 运行
plot_loss.append(d_loss)
if i % args.print_loss == 0 and i > 0:
print('Iteration is : %d, Loss is: %f' % (i, d_loss)) # 打印损失
if i % args.plot_loss == 0 and i > 0: # 绘图
plt.figure(figsize=(6*1.1618, 6))
plt.plot(range(len(plot_loss)), plot_loss)
plt.xlabel('iteration times')
plt.ylabel('lost')
plt.show()
if i % 500 == 0 and i > 0:
save_model(saver, sess, i)
- 模块的导入与存储
def save_model(saver, sess, step): # 存储模块
saver.save(sess, os.path.join(args.model_dir, "classification"), global_step=step)
def load_model(sess): # 导入模块
saver = tf.train.Saver()
checkpoint = tf.train.get_checkpoint_state(args.model_dir)
if checkpoint and checkpoint.model_checkpoint_path:
saver.restore(sess, checkpoint.model_checkpoint_path)
print("Successfully loaded:", checkpoint.model_checkpoint_path)
else:
print("Could not find any old weights!")
return saver
- 主函数
def main(_):
training()
if __name__ == "__main__":
tf.app.run()
从上往下黏贴就行,贴到IDE下就可以运行,还可以打印损失函数
鬼知道为什么下降这么快,,,