13.1 典型的模型保存方法
- train.Saver类是TensorFlow 1.x提供的用于保存和还原神经网络模型的低阶API。
import tensorflow as tf
import numpy as np

# Two small variables and their sum; the sum op is named "add" by default.
a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name="a")
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name="b")
result = a + b
# tf.initialize_all_variables() is deprecated; global_variables_initializer()
# is the supported replacement.
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    # Persist the model. This is the call that produces the checkpoint,
    # .data, .index and .meta files described below.
    saver.save(sess, "/home/xxy/model/model.ckpt")
调用saver.save()保存模型后会产生4个文件,其中
-
checkpoint文件是一个文本文件,保存了一个目录下所有模型文件列表。checkpoint文件会被自动更新。
-
model.ckpt.data-00000-of-00001文件保存了Tensorflow程序中每一个变量的取值
-
model.ckpt.index文件保存了每一个变量的名称,是一个string-string的table,其中table的key值为tensor名,value值为序列化的BundleEntryProto
-
model.ckpt.meta文件保存了计算图的结构,或者说是神经网络的结构
-
restore()函数需要在模型参数恢复前定义计算图上的所有运算,并且变量名需要与模型中存储的变量名一致,这样就可以将变量的值通过已保存的模型加载进来。
import tensorflow as tf
import numpy as np

# The graph must be rebuilt with the same variable names ("a", "b") as were saved.
a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name="a")
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name="b")
result = a + b
saver = tf.train.Saver()
with tf.Session() as sess:
    # No initializer is run: restore() assigns the saved values to a and b.
    # The path is the checkpoint prefix used by saver.save(), i.e. "model.ckpt".
    saver.restore(sess, "/home/xxy/model/model.ckpt")
    print(sess.run(result))
- import_meta_graph()直接加载已经持久化的计算图。其输入参数为一个.meta文件的路径。它返回一个Saver实例,再调用该实例的restore()函数就可以恢复其参数了。
import tensorflow as tf

# Rebuild the saved graph from the .meta file; this returns a Saver for it,
# so the graph does not need to be redefined in Python.
meta_graph = tf.train.import_meta_graph("/home/xxy/model/model.ckpt.meta")
with tf.Session() as sess:
    meta_graph.restore(sess, "/home/xxy/model/model.ckpt")
    # "add:0" is the first output of the node named "add" (the a+b op).
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
- 保存和加载部分变量
import tensorflow as tf

a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name="a")
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name="b")
result = a + b
# Only the variables in the list are handled by this Saver.
saver = tf.train.Saver([a])
with tf.Session() as sess:
    saver.restore(sess, "/home/xxy/model/model.ckpt")
    # Only `a` was restored; `b` is never initialized here, so `result`
    # cannot be evaluated in this session, only `a`.
    print(sess.run(a))
与加载类似,保存部分变量也可以在声明train.Saver类时传入一个变量列表来指定,例如tf.train.Saver([a])。
- 保存或加载时给变量重新命名。
import tensorflow as tf

# The in-graph names ("a2", "b2") differ from the names stored in the checkpoint.
a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name="a2")
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name="b2")
result = a + b
# The dict maps checkpoint names -> variables in the current graph.
saver = tf.train.Saver({"a": a, "b": b})
with tf.Session() as sess:
    # Checkpoint prefix is "model.ckpt" (the original "model/ckpt" was a typo).
    saver.restore(sess, "/home/xxy/model/model.ckpt")
    print(sess.run(result))
- 滑动平均变量保存方式
writer
import tensorflow as tf

a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name="a")
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name="b")

# Exponential moving average with decay 0.99. apply() creates one shadow
# variable per listed variable and returns the op that updates them.
# (tf.all_variables() is deprecated; tf.global_variables() replaces it.)
averages_class = tf.train.ExponentialMovingAverage(0.99)
averages_op = averages_class.apply(tf.global_variables())

# Lists a and b together with the shadow variables created by apply().
for variables in tf.global_variables():
    print(variables.name)

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    # Update the variables with assign(). The new value must match the
    # variables' shape [2]; assigning a bare scalar fails shape validation.
    sess.run(tf.assign(a, [10.0, 10.0]))
    sess.run(tf.assign(b, [5.0, 5.0]))
    sess.run(averages_op)
    # The default Saver stores both the variables and their shadow variables.
    saver.save(sess, "/home/xxy/model/model2.ckpt")
    print(sess.run([a, averages_class.average(a)]))
    print(sess.run([b, averages_class.average(b)]))
reader
import tensorflow as tf

a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name="a")
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name="b")
# Moving-average definition (decay 0.99), used here only to build the name map.
averages_class = tf.train.ExponentialMovingAverage(0.99)
# Load the saved shadow values "x/ExponentialMovingAverage" directly into a and b.
saver = tf.train.Saver({"a/ExponentialMovingAverage": a,
                        "b/ExponentialMovingAverage": b})
# Equivalent: variables_to_restore() of train.ExponentialMovingAverage builds
# the same dictionary automatically:
#     saver = tf.train.Saver(averages_class.variables_to_restore())
with tf.Session() as sess:
    # restore() is a Saver method; tf.Session has no restore().
    saver.restore(sess, "/home/xxy/model/model2.ckpt")
    print(sess.run([a, b]))
    print(averages_class.variables_to_restore())
13.2 模型持久化的原理
1. model.ckpt.meta文件
model.ckpt.meta文件存储的是TensorFlow程序的元图数据(MetaGraph),也就是计算图的节点等信息。元图数据的存储格式为MetaGraphDef。
// Top-level message stored in the .meta file.
message MetaGraphDef {
  MetaInfoDef meta_info_def = 1;                   // meta information, incl. stripped_op_list
  GraphDef graph_def = 2;                          // the graph's nodes
  SaverDef saver_def = 3;                          // configuration of the Saver
  map<string, CollectionDef> collection_def = 4;   // named collections
  map<string, SignatureDef> signature_def = 5;
}
import tensorflow as tf

a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name="a")
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name="b")
result = a + b
saver = tf.train.Saver()
# Export the MetaGraphDef; as_text=True writes the readable protobuf text
# format instead of the binary form.
saver.export_meta_graph("/home/xxy/model_ckpt_meta_json", as_text=True)
- meta_info_def属性
message MetaInfoDef {
  string meta_graph_version = 1;
  OpList stripped_op_list = 2;      // OpDefs of every op used by the graph
  google.protobuf.Any any_info = 3;
  repeated string tags = 4;
}
其中stripped_op_list记录了计算图中用到的所有运算的信息,每个运算以OpDef类型的消息保存。例如,计算图中加法运算的attr属性如下:
# Text-format OpDef.AttrDef: the type attribute "T" and the dtypes it may take.
attr {
  name: "T"
  type: "type"
  allowed_values {
    list {
      type: DT_HALF
      type: DT_FLOAT
      type: DT_DOUBLE
      type: DT_UINT8
      type: DT_INT8
      type: DT_INT16
      type: DT_INT32
      type: DT_INT64
      type: DT_COMPLEX64
      type: DT_COMPLEX128
      type: DT_STRING
    }
  }
}
- graph_def属性
// The computation graph: a list of nodes plus version information
// (other fields of the real definition are omitted here).
message GraphDef {
  repeated NodeDef node = 1;
  VersionDef versions = 4;
}
// One node of the computation graph; GraphDef.node is a list of these.
message NodeDef{
string name=1; // node name, e.g. "add" for the a+b op in the examples
string op=2; // name of the operation this node runs (described by an OpDef)
repeated string input=3; // names of this node's inputs
string device =4; // device placement of the node
map<string,AttrValue> attr =5; // attribute values keyed by attribute name (e.g. "T")
};
- saver_def属性
// Configuration of the Saver that wrote the checkpoint.
message SaverDef {
  string filename_tensor_name = 1;          // tensor holding the checkpoint file name
  string save_tensor_name = 2;              // tensor to run to save
  string restore_op_name = 3;               // op to run to restore
  int32 max_to_keep = 4;                    // number of recent checkpoints kept
  bool sharded = 5;
  float keep_checkpoint_every_n_hours = 6;
  enum CheckpointFormatVersion {
    LEGACY = 0;
    V1 = 1;
    V2 = 2;
  }
  CheckpointFormatVersion version = 7;
}
# Text-format saver_def field as exported by export_meta_graph(as_text=True).
saver_def {
  filename_tensor_name: "save/Const:0"
  save_tensor_name: "save/control_dependency:0"
  restore_op_name: "save/restore_all"
  max_to_keep: 5
  keep_checkpoint_every_n_hours: 10000.0
  version: V2
}
- collection_def属性
// A named collection; exactly one list kind is set (oneof).
message CollectionDef {
  message NodeList {
    repeated string value = 1;
  }
  message BytesList {
    repeated bytes value = 1;
  }
  message Int64List {
    repeated int64 value = 1 [packed = true];
  }
  message FloatList {
    repeated float value = 1 [packed = true];
  }
  message AnyList {
    repeated google.protobuf.Any value = 1;
  }
  oneof kind {
    NodeList node_list = 1;
    BytesList bytes_list = 2;
    Int64List int64_list = 3;
    FloatList float_list = 4;
    AnyList any_list = 5;
  }
}
2.从.index和.data文件中读取变量的值
import tensorflow as tf

# Read tensors straight from the .index/.data files, without building a graph.
reader = tf.train.NewCheckpointReader("/home/xxy/model/model.ckpt")
# dict: variable name -> shape, for every tensor stored in the checkpoint.
all_variables = reader.get_variable_to_shape_map()
print(all_variables)
for variable_name, shape in all_variables.items():
    print(variable_name, "shape is:", shape)
print("Value for variable a is : ", reader.get_tensor("a"))
print("Value for variable b is : ", reader.get_tensor("b"))
13.3 在Tensorflow 2.0中实现模型保存
save
import tensorflow as tf
from tensorflow.keras import layers

# Fashion-MNIST: flatten each 28x28 image to a 784-vector, scale to [0, 1].
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
train_images = train_images.reshape(-1, 784).astype('float32') / 255
test_images = test_images.reshape(-1, 784).astype('float32') / 255

# Functional-API MLP: 784 -> 500 (ReLU) -> 10 (softmax).
inputs = tf.keras.Input(shape=(784,), name='digits')
x = layers.Dense(500, activation='relu', name='dense_1')(inputs)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
mlpmodel = tf.keras.Model(inputs=inputs, outputs=outputs, name='MLPModel')
mlpmodel.summary()
mlpmodel.compile(loss='sparse_categorical_crossentropy',
                 optimizer=tf.keras.optimizers.SGD(),
                 metrics=['accuracy'])
mlpmodel.fit(x=train_images, y=train_labels, epochs=10, batch_size=100,
             validation_data=(test_images, test_labels))
# Save the whole model into a single HDF5 file.
mlpmodel.save("/home/xxy/model/model.h5")
read
import tensorflow as tf
import numpy as np

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# Same preprocessing as at training time: flatten to 784 and scale to [0, 1].
test_images = test_images.reshape(-1, 784).astype('float32') / 255
# Rebuild the model saved by mlpmodel.save() from the HDF5 file.
load_mlpmodel = tf.keras.models.load_model("/home/xxy/model/model.h5")
class_name = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankleboot']
predictions = load_mlpmodel.predict(test_images)
# Compare the predicted and true class for the first 100 test images.
for prediction, label in zip(predictions[:100], test_labels[:100]):
    print("predict class result is : ", class_name[np.argmax(prediction)])
    print("correct class result is : ", class_name[label])
save
import tensorflow as tf
from tensorflow.keras import layers

# Fashion-MNIST: flatten each 28x28 image to a 784-vector, scale to [0, 1].
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
train_images = train_images.reshape(-1, 784).astype('float32') / 255
test_images = test_images.reshape(-1, 784).astype('float32') / 255

# Functional-API MLP: 784 -> 500 (ReLU) -> 10 (softmax).
inputs = tf.keras.Input(shape=(784,), name='digits')
x = layers.Dense(500, activation='relu', name='dense_1')(inputs)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
mlpmodel = tf.keras.Model(inputs=inputs, outputs=outputs, name='MLPModel')
mlpmodel.summary()
mlpmodel.compile(loss='sparse_categorical_crossentropy',
                 optimizer=tf.keras.optimizers.SGD(),
                 metrics=['accuracy'])
mlpmodel.fit(x=train_images, y=train_labels, epochs=10, batch_size=100,
             validation_data=(test_images, test_labels))
# Export in SavedModel format (TF 2.0 experimental API).
# NOTE(review): this API is deprecated in later TF 2.x releases in favour of
# mlpmodel.save(path, save_format='tf'). The export directory here also holds
# the checkpoint/.h5 files written earlier; a dedicated empty subdirectory
# (with the loader path updated to match) is likely safer — confirm.
tf.keras.experimental.export_saved_model(mlpmodel, "/home/xxy/model/")
read
import tensorflow as tf
import numpy as np

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# Same preprocessing as at training time: flatten to 784 and scale to [0, 1].
test_images = test_images.reshape(-1, 784).astype('float32') / 255
# Counterpart of tf.keras.experimental.export_saved_model().
load_mlpmodel = tf.keras.experimental.load_from_saved_model("/home/xxy/model/")
class_name = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankleboot']
predictions = load_mlpmodel.predict(test_images)
# Compare the predicted and true class for the first 100 test images.
for prediction, label in zip(predictions[:100], test_labels[:100]):
    print("predict class result is : ", class_name[np.argmax(prediction)])
    print("correct class result is : ", class_name[label])
# Save/restore only the architecture: get_config() returns a plain dict and
# Model.from_config() builds a fresh, untrained model from it.
# Assumes `mlpmodel` is the functional model built in the save script.
config = mlpmodel.get_config()
load_mlpmodel = tf.keras.Model.from_config(config)
save_weight
import tensorflow as tf
from tensorflow.keras import layers

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# Both splits must be flattened to 784 features: train_images is fed to fit()
# below, and the model must be built with the same input size used at load time.
train_images = train_images.reshape(-1, 784).astype('float32') / 255
test_images = test_images.reshape(-1, 784).astype('float32') / 255
class MLPModel(tf.keras.Model):
    """Subclassed MLP: a 500-unit ReLU layer followed by a 10-way softmax layer."""

    def __init__(self, name=None):
        super(MLPModel, self).__init__(name=name)
        self.dense = layers.Dense(500, activation='relu', name='dense')
        self.dense_1 = layers.Dense(10, activation='softmax', name='dense_1')

    def call(self, inputs):
        """Run the forward pass and return class probabilities."""
        x = self.dense(inputs)
        return self.dense_1(x)
mlpmodel = MLPModel()
# The keyword is `optimizer`; the misspelled `optimzer` is rejected by compile().
mlpmodel.compile(loss='sparse_categorical_crossentropy',
                 optimizer=tf.keras.optimizers.SGD())
history = mlpmodel.fit(train_images, train_labels, batch_size=100, epochs=10)
# Subclassed models cannot be serialized as HDF5 architectures, so only the
# weights are saved, in TensorFlow checkpoint format.
mlpmodel.save_weights("/home/xxy/model/", save_format='tf')
read_weight
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# Same preprocessing as at training time: flatten to 784 and scale to [0, 1].
test_images = test_images.reshape(-1, 784).astype('float32') / 255
class MLPModel(tf.keras.Model):
    """Same architecture as at save time, so the saved weights can be matched."""

    def __init__(self, name=None):
        super(MLPModel, self).__init__(name=name)
        self.dense = layers.Dense(500, activation='relu', name='dense')
        self.dense_1 = layers.Dense(10, activation='softmax', name='dense_1')

    def call(self, inputs):
        """Run the forward pass and return class probabilities."""
        x = self.dense(inputs)
        return self.dense_1(x)
new_model = MLPModel()
# The keyword is `optimizer`; the misspelled `optimzer` is rejected by compile().
new_model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=tf.keras.optimizers.SGD())
new_model.load_weights('/home/xxy/model/')
class_name = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankleboot']
new_predictions = new_model.predict(test_images)
# Compare the predicted and true class for the first 100 test images.
for prediction, label in zip(new_predictions[:100], test_labels[:100]):
    print("predict class result is : ", class_name[np.argmax(prediction)])
    print("correct class result is : ", class_name[label])
13.4 PB文件
writer
import tensorflow as tf
# graph_util lives in tensorflow/python/framework/graph_util.py
from tensorflow.python.framework import graph_util

a = tf.Variable(tf.constant([1.0], shape=[1]), name="a")
b = tf.Variable(tf.constant([3.0], shape=[1]), name="b")
result = a + b
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    # Freeze variables into constants, keeping only what is needed to compute
    # the node named "add"; the result is a self-contained GraphDef.
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    with tf.gfile.GFile("/home/xxy/model/model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())
reader
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    # Parse the serialized, frozen GraphDef written by the PB writer.
    with gfile.GFile("/home/xxy/model/model.pb", "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into the session's (default) graph and fetch the output tensor "add:0".
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))