先讲一下Tensorflow模型的组成,如果是通过tf.train.Saver()保存的模型,那么会生成3种文件:
- .meta是网络结构,就是深度学习网络的那些隐层和全连接层等的定义
- checkpoint是记录输出模型的checkpoint
- 剩下的文件保存的是模型网络中的具体的参数
下面看下训练代码:
代码很简单,先定义两个Tensor的变量w1和w2,b1是一个常量2,然后定义一个字典,其中w1是4,w2是8。接着定义op,op指的是Tensorflow计算符,tf.add将w1和w2相加,然后通过tf.multiply将w1和w2相加的结果乘以2。接着生成全局的参数tf.global_variables_intitializer(),就是初始化参数,取第1000次的checkpoint把模型保存为my_test_model。这个代码的意思是输入w1和w2,然后模型会返回(w1+w2)*b1的结果,b1是常量,等于2。
运行后模型就保存下来,下面看下怎么调用:
通过import_meta_graph这个函数加载训练时的网络结构,然后用restore方法加载网络结构中的权重,到了这步预测模型就加载好了。接着设置一组预测值,使得w1=6,w2=7。获取计算op,也就是当初训练的时候定义的op名称‘op_to_restore’。然后就可以把数据传到op里进行计算,生成的结果为(6+7)*2=26。
彩蛋:如果想基于已有的模型refine,可以在原有模型上增加计算op,参考第二张图注释部分,可以自己试下。