关于生成.pb的作用:
.pb文件为训练模型最终生成的文件。我们使用它的目的是为了不再Android项目中进行训练而是直接套用训练的成功。
关于怎么训练模型:
本示例基于mnist手写识别项目。关于mnist手写识别训练模型的部分请参考tensorflow极客学院的中文教程中的初识mnist和深入mnist,因此关于怎么训练模型再此不在赘述。
关于怎么导出模型:
给我们导出模型的方法来自于我发现的TensorflowMnistAdroid dem。前面训练模型和极客学院的官网示例代码大同小异。重点在于怎么导入模型,现分析如下:
第一步:
保存训练结果。也就是保存训练出来的变量的值。
相关代码:
# Store variable
_W_conv1 = W_conv1.eval(sess)
_b_conv1 = b_conv1.eval(sess)
_W_conv2 = W_conv2.eval(sess)
_b_conv2 = b_conv2.eval(sess)
_W_fc1 = W_fc1.eval(sess)
_b_fc1 = b_fc1.eval(sess)
_W_fc2 = W_fc2.eval(sess)
_b_fc2 = b_fc2.eval(sess)
sess.close()
其中,eval的用法如下:
eval()函数的大致用法是给session中的张量Evaluates(给。。。评估,定值)调用这个方法会执行输入变量的所有操作。个人感觉就是把session中的张量保存下来。详细的用法请参照附录api使用。
第二步:
新建一个图,用于导出训练结果即.pb文件。
关于怎么新建图,也与前面训练模型中使用的方法大同小异。导出.pb文件关键代码在于:
graph_def = g_2.as_graph_def()
tf.train.write_graph(graph_def, export_dir, 'expert-graph.pb', as_text=False)
export_dir变量为前面创建的存放文件的目录.
g_2为新创建的图,as_graph_def()的官方解释的作用是:Returns a serialized GraphDef representation of this graph。tf.train.write_graph的作用是如其名一样是输出一个模型到一个文件。