mnist数据集:包含7万张黑底白字手写数字图片,其中55000张作为训练集,5000张作为验证集,10000作为测试集。每张图片大小为28X28像素,图片中纯黑色像素值为0,纯白1。数据集的标签长度为10的一维数组,数组每个元素索引号表示对应数字出现的概率。
在将mnist数据集作为输入喂入神经网络时,需先将数据集中每张图片变成长度784一维数组,将该数组作为输入特征喂入神经网络。
from tensorflow.example.tutorials.minst import input_data
mnist=input_data.read_data_sets(‘.data/,one_hot=True)
第一个参数表示数据集存放路径,第二个参数表示数据集的存取形式。当第二个参数为true时,表示以独热码形式存取数据集。read_data_sets()函数运行时,会检查指定路径内是否已经有数据集,若指定路径没有数据集,则自动下载,并将mnist数据集分为训练集train,验证集validation和测试集test存放。
常用函数:
1),tf.get_collection(“”)表示从collection集合中取出全部变量生成一个列表。
2),tf.cast(x,dtype)表示将参数转为指定数据类型。
3),tf.equal()表示对比两个矩阵或向量的元素。若对应元素相等,则返回true,否则false
4),tf.reduce_mean(x,axis)表示求矩阵或张量指定维度的平均值。若不指定第二份参数,则在所有元素中去平均值,指定第二个参数为0,则每一列求平均值,指定第二个参数为1,则每一行求平均值。
5),tf.argmax(x,axis)表示返回指定维度axis下,参数x中最大索引号。
6),os.path.join()表示把参数字符串按照路径命名规则拼接。、
例如:
import os
os.path.join(‘/hello’,’/good/’,’boy/’)
输出结果:/hello/good/boy/
7),字符串.split()表示按照指定“拆分符”对字符串拆分,返回拆分列表。
例如:
‘./model/mnist_model-1001’.split(‘/’)[-1].split(‘-1’)[-1]
在该例子中,共进行两次拆分。第一个拆分’/’,返回拆分列表,并提取拆分列表,并提取列表索引为-1的元素即倒数第一个元素,第二个拆分符为‘-’返回拆分列表,并提取列表索引为-1的元素即倒数第一个元素,故函数返回值为1001.
8),tf,Graph().as_default()函数表示当前图设置为默认图,并返回一个上下文管理器。
例如:
with tf.Graph().as_default() as g,表示将在Graph()内定义的节点加入到计算图g中。
神经网络模型保存:
在反向传播过程中,一般会间隔一定轮数保存一次神经网络模型,并产生三个文件(保存当前图结构的.meta文件、保存当前参数名的.index文件、保存当前参数的.data文件)
saver=tf.train.Saver()
with tf.Session() as sess:
for i in range(steps):
if i %轮数==0:
saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)
上述代码表示,神经网络没循环规定的轮数,将神经网络模型中所有的参数等信息保存到指定的路径中,并在存放模型的文件夹名称中注明保存模型时的训练轮数。
神经网络模型的加载:
在测试网络效果时,需要将训练好的神经网络模型加载:
with tf.Session() as sess:
ckpt=tf.train