- 如果你想保存在tensorflow上辛苦训练了很久的模型,随时地去使用它
- tensorflow将上述过程分成了两个部分:
训练和保存
提取和使用
- 训练部分为:加载训练数据,前向传播计算,代价函数评估,反向传播更新,保存
- 提取部分为:
加载测试数据(格式与训练数据保持一致)
前向传播计算(框架与训练部分一致)
- 我将以经典的手写数字识别数据集和lenet5框架举例
代码:训练与保存篇
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
tf.set_random_seed(1)
mnist = input_data.read_data_sets(r'MNIST_data')
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.int64,[None])
x_img = tf.reshape(x,[-1,1,28,28])
x_img = tf.transpose(x_img,perm=[0,2,3,1])
conv1_1 = tf.layers.conv2d(x_img,8,(3,3),padding='same',activation=tf.nn.relu,name='conv1_1')
pool1 = tf.layers.max_pooling2d(conv1_1,(2,2),(2,2),name='pool1')
conv2_1 = tf.layers.conv2d(pool1,