上回已经可以得到一个清晰的模型结构,但是还是不够满意,为什么呢,因为预测的时候不需要dropout层,所以想修改接口,直接去除dropout层。
以下方法自己想的,可能有别的更好的方法:
首先,训练的时候得把每层的参数起好名字,否则名字都自动起的,怎么修改网络,直接就懵逼了。
#定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob=tf.placeholder(tf.float32)
lr = tf.Variable(0.001, dtype=tf.float32)
#创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784,500],stddev=0.1),name="w1")
b1 = tf.Variable(tf.zeros([500])+0.1,name="b1")
L1 = tf.nn.tanh(tf.matmul(x,W1)+b1,name="net1")
L1_drop = tf.nn.dropout(L1,keep_prob)
W2 = tf.Variable(tf.truncated_normal([500,300],stddev=0.1),name="w2")
b2 = tf.Variable(tf.zeros([300])+0.1,name="b2")
L2 = tf.nn.tanh(tf.matmul(L1_drop,W2)+b2,name="net2")
L2_drop = tf.nn.dropout(L2,keep_prob)
W3 = tf.Variable(tf.truncated_normal([300,10],stddev=0.1),name="w3")
b3 = tf.Variable(tf.zeros([10])+0.1,name="b3")
prediction = tf.nn.softmax(tf.matmul(L2_drop,W3)+b3,name="net3")
和以前一样把训练结果保存为.ckpt文件
第二步,自己新写一个网络结构,手动去除dropout层,不用训练,当然也不需要最终标签,acc等网络结构,新网络参数的名字和训练网络一致(不一致load的时候得指定关系,得累死),然后导入ckpt文件,保存成pb文件
#定义修改后的网络结构
x = tf.placeholder(tf.float32,[None,784])
W1 = tf.Variable(tf.truncated_normal([784,500],stddev=0.1),name="w1")
b1 = tf.Variable(tf.zeros([500])+0.1,name="b1")
L1 = tf.nn.tanh(tf.matmul(x,W1)+b1,name="net1")
W2 = tf.Variable(tf.truncated_normal([500,300],stddev=0.1),name="w2")
b2 = tf.Variable(tf.zeros([300])+0.1,name="b2")
L2 = tf.nn.tanh(tf.matmul(L1,W2)+b2,name="net2")
W3 = tf.Variable(tf.truncated_normal([300,10],stddev=0.1),name="w3")
b3 = tf.Variable(tf.zeros([10])+0.1,name="b3")
prediction = tf.nn.softmax(tf.matmul(L2,W3)+b3,name="net3")
saver = tf.train.Saver() #声明saver用于保存模型
import os
ckpt_dir = "./pb_dir"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
with tf.Session() as sess:
saver.restore(sess,"ckpt_dir/Test2.ckpt-19")
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['net3'])
with tf.gfile.FastGFile(ckpt_dir+'/Test2.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
第三步,和以前新开个代码,直接导入pb文件,写入tensorboard
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util
#载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
from tensorflow.python.platform import gfile
with tf.gfile.FastGFile('pb_dir/Test2.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
writer = tf.summary.FileWriter("logs/", sess.graph)
NetOut=sess.run("net3:0",feed_dict={"Placeholder:0":mnist.test.images})
import numpy as np
Prediction = NetOut.argmax(axis=1)#找到最大值位置
TestLabels = mnist.test.labels.argmax(axis=1)#找到最大值位置
err = 0
for i in range(Prediction.shape[0]):
if (Prediction[i]!=TestLabels[i]):
err = err+1
err=err/Prediction.shape[0]
acc = 1-err
print("Acc=",acc)
打开tensorboard可以看到修改后的网络结构图:没有dropout了哦!
这样就真正达到了随心所欲,自由自在