TensorFlow Study Notes 13: Model Persistence

13.1 Typical Ways to Save a Model

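  • The train.Saver class is the low-level API that TensorFlow 1.x provides for saving and restoring a neural network model.

import tensorflow as tf

# Two variables and an op that adds them
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b

init_op=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # Save the graph and the current variable values under /home/xxy/model/
    saver.save(sess,"/home/xxy/model/model.ckpt")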

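Saving produces four files:

  • checkpoint is a plain-text file listing all model files saved in the directory. It is updated automatically whenever a new model is saved.

  • model.ckpt.data-00000-of-00001 stores the value of every variable in the TensorFlow program.

  • model.ckpt.index stores the name of every variable. It is a string-to-string table whose key is the tensor name and whose value is a serialized BundleEntryProto describing that tensor (dtype, shape, and where its data sits in the .data file).

  • model.ckpt.meta stores the structure of the computation graph, that is, the structure of the neural network.

  • restore() requires every operation of the computation graph to be defined before the parameters are restored. The variable names must also match those stored in the model; the saved values are then loaded into the variables.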

import tensorflow as tf

# Redefine the same graph; variable names must match those in the checkpoint
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b

saver=tf.train.Saver()

with tf.Session() as sess:
    # No initializer is needed: restore() assigns the saved values
    saver.restore(sess,"/home/xxy/model/model.ckpt")
    print(sess.run(result))
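
The prefix passed to restore() does not have to be hard-coded. It can be looked up in the checkpoint text file described above: tf.train.latest_checkpoint(directory) reads that file and returns the most recent prefix. To show what the file holds, here is a minimal stdlib sketch that parses it. It assumes the usual text-proto lines `model_checkpoint_path: "..."` and `all_model_checkpoint_paths: "..."`. The helper name and the sample contents are hypothetical.

```python
import re


def parse_checkpoint_state(text):
    """Parse the text-format `checkpoint` file into (latest, all_paths).

    Simplified sketch (hypothetical helper): only the two fields below are
    handled, and each value must be a double-quoted string on its own line.
    """
    latest = None
    all_paths = []
    for line in text.splitlines():
        match = re.match(r'\s*(\w+):\s*"(.*)"\s*$', line)
        if not match:
            continue
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


# Hypothetical contents after saving model.ckpt at global steps 1 and 2
sample = '''model_checkpoint_path: "model.ckpt-2"
all_model_checkpoint_paths: "model.ckpt-1"
all_model_checkpoint_paths: "model.ckpt-2"
'''
latest, all_paths = parse_checkpoint_state(sample)
print(latest)     # model.ckpt-2
print(all_paths)  # ['model.ckpt-1', 'model.ckpt-2']
```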
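
  • import_meta_graph() loads a persisted computation graph directly. Its argument is the path of a .meta file, and it returns a Saver instance; calling that instance's restore() then restores the parameters.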
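import tensorflow as tf

# Rebuild the graph from the .meta file; a, b and result need not be redefined
meta_graph=tf.train.import_meta_graph("/home/xxy/model/model.ckpt.meta")
with tf.Session() as sess:
    meta_graph.restore(sess,"/home/xxy/model/model.ckpt")
    # "add:0" is the first output of the op named "add", i.e. result
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))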
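
get_tensor_by_name() expects a name of the form <op_name>:<output_index>. result=a+b creates an op named add (the default name for an addition), and its first output is therefore add:0.

Below is a minimal stdlib sketch of that naming rule. split_tensor_name is a hypothetical helper, not a TensorFlow function. It rejects a bare op name such as "add", just as get_tensor_by_name() does.

```python
def split_tensor_name(tensor_name):
    """Split a tensor name such as "add:0" into (op_name, output_index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        # A name without ":<index>" refers to an operation, not a tensor
        raise ValueError(f"{tensor_name!r} is not of the form <op_name>:<output_index>")
    return op_name, int(index)


print(split_tensor_name("add:0"))         # ('add', 0)
print(split_tensor_name("save/Const:0"))  # ('save/Const', 0)
```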
  • Saving and loading only some of the variables
import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b

# This Saver only handles a
saver=tf.train.Saver([a])

with tf.Session() as sess:
    saver.restore(sess,"/home/xxy/model/model.ckpt")
    print(sess.run(a))

Likewise, passing a list of variables when constructing train.Saver limits saving to those variables.

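  • Renaming variables when saving or loading
import tensorflow as tf
# The variables are now named a2 and b2, unlike the names stored in the checkpoint
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a2")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b2")
result=a+b

# Map checkpoint name -> variable: saved "a" is loaded into a2, saved "b" into b2
saver=tf.train.Saver({"a":a,"b":b})

with tf.Session() as sess:
    saver.restore(sess,"/home/xxy/model/model.ckpt")
    print(sess.run(result))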
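
The dictionary passed to Saver maps a name stored in the checkpoint (the key) to a variable in the current graph (the value). That is why variables named a2 and b2 can receive the values saved under a and b. A list such as [a] behaves like {"a": a}, with each variable keyed by its own op name.

Below is a framework-free sketch of that lookup. Plain dicts stand in for the checkpoint and for tf.Variable objects. restore_by_mapping and the dict layout are hypothetical.

```python
def restore_by_mapping(checkpoint, var_list):
    """Copy saved values into variables according to a Saver-style mapping.

    checkpoint: dict of saved name -> saved value (stands in for .index/.data)
    var_list:   dict of saved name -> variable; each "variable" here is a
                plain dict with "name" and "value" keys (hypothetical layout)
    """
    for saved_name, variable in var_list.items():
        if saved_name not in checkpoint:
            # A key missing from the checkpoint is an error in TensorFlow too
            raise KeyError(f"{saved_name!r} not found in checkpoint")
        variable["value"] = list(checkpoint[saved_name])


saved = {"a": [1.0, 2.0], "b": [3.0, 4.0]}  # what model.ckpt holds
a2 = {"name": "a2", "value": None}           # variable named a2 in the new graph
b2 = {"name": "b2", "value": None}
restore_by_mapping(saved, {"a": a2, "b": b2})
print(a2["value"], b2["value"])  # [1.0, 2.0] [3.0, 4.0]
```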

  • Saving moving-average (shadow) variables

writer

import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")

# Define the moving averages: one shadow variable per variable, decay 0.99
averages_class=tf.train.ExponentialMovingAverage(0.99)
averages_op=averages_class.apply(tf.global_variables())

# Lists a, b and their shadow variables a/ExponentialMovingAverage, b/ExponentialMovingAverage
for variable in tf.global_variables():
    print(variable.name)

init_op=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # assign() updates the variable values (the value must match the shape [2])
    sess.run(tf.assign(a,[10.0,10.0]))
    sess.run(tf.assign(b,[5.0,5.0]))

    # Update the shadow variables, then save the variables together with their shadows
    sess.run(averages_op)
    saver.save(sess,"/home/xxy/model/model2.ckpt")
    print(sess.run([a,averages_class.average(a)]))
    print(sess.run([b,averages_class.average(b)]))
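
Without a num_updates argument, ExponentialMovingAverage keeps one shadow value per variable. Each run of averages_op updates it as shadow = decay * shadow + (1 - decay) * variable, and the shadow starts at the variable's initial value.

Here is a stdlib sketch of a single averages_op step for the writer above. It assumes decay 0.99, initial values [1, 2] and [3, 4], and a set to 10 and b set to 5 in every element. TensorFlow computes in float32, so its printed values may differ in the last digits.

```python
def ema_step(shadow, value, decay=0.99):
    """One ExponentialMovingAverage update: decay * shadow + (1 - decay) * value."""
    return [decay * s + (1 - decay) * v for s, v in zip(shadow, value)]


# Shadows start at the variables' initial values [1, 2] and [3, 4]
shadow_a = ema_step([1.0, 2.0], [10.0, 10.0])
shadow_b = ema_step([3.0, 4.0], [5.0, 5.0])
print([round(x, 4) for x in shadow_a])  # [1.09, 2.08]
print([round(x, 4) for x in shadow_b])  # [3.02, 4.01]
```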

reader

import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")

# Define the moving-average object (needed here for variables_to_restore())
averages_class=tf.train.ExponentialMovingAverage(0.99)

# Load the saved shadow values into a and b
saver=tf.train.Saver({"a/ExponentialMovingAverage":a,
                      "b/ExponentialMovingAverage":b})
# ExponentialMovingAverage.variables_to_restore() generates the dictionary above
# automatically, so the following line is equivalent:
# saver=tf.train.Saver(averages_class.variables_to_restore())

with tf.Session() as sess:
    saver.restore(sess,"/home/xxy/model/model2.ckpt")
    print(sess.run([a,b]))
    print(averages_class.variables_to_restore())
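
variables_to_restore() returns a dictionary from checkpoint names to variables. For each averaged variable, the key is the shadow variable's name: the variable's op name followed by /ExponentialMovingAverage. Other variables keep their own names.

Below is a stdlib sketch of that key construction, working on plain name strings. shadow_restore_map is hypothetical and simplified: the real method decides which variables are averaged from the graph's collections.

```python
def shadow_restore_map(var_names, averaged, suffix="ExponentialMovingAverage"):
    """Build a {checkpoint key: variable name} map for restoring moving averages.

    var_names: op names of the model's variables, e.g. ["a", "b"]
    averaged:  names whose shadow (averaged) value should be loaded
    """
    mapping = {}
    for name in var_names:
        key = f"{name}/{suffix}" if name in averaged else name
        mapping[key] = name
    return mapping


print(shadow_restore_map(["a", "b"], {"a", "b"}))
# {'a/ExponentialMovingAverage': 'a', 'b/ExponentialMovingAverage': 'b'}
```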

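13.2 How Model Persistence Works
1. The model.ckpt.meta file
The model.ckpt.meta file stores the meta graph of a TensorFlow program, which includes the node information of the computation graph. The meta graph is stored as a MetaGraphDef protocol buffer: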

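message MetaGraphDef {
    MetaInfoDef meta_info_def = 1;
    GraphDef graph_def = 2;
    SaverDef saver_def = 3;
    map<string, CollectionDef> collection_def = 4;
    map<string, SignatureDef> signature_def = 5;
}

The following code exports the MetaGraphDef of a small graph in readable text form so its contents can be inspected:

import tensorflow as tf
a=tf.Variable(tf.constant([1.0,2.0],shape=[2]),name="a")
b=tf.Variable(tf.constant([3.0,4.0],shape=[2]),name="b")
result=a+b
saver=tf.train.Saver()
# as_text=True writes the MetaGraphDef in protobuf text format instead of binary
saver.export_meta_graph("/home/xxy/model_ckpt_meta_json",as_text=True)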
  • The meta_info_def attribute
message MetaInfoDef {
    string meta_graph_version = 1;
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
}

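stripped_op_list records every operation used in the computation graph, each as an OpDef message. An operation that appears several times in the graph is listed only once. For example, the attr part of one OpDef looks like this:

attr {
    name: "T"
    type: "type"
    allowed_values {
        list {
            type: DT_HALF
            type: DT_FLOAT
            type: DT_DOUBLE
            type: DT_UINT8
            type: DT_INT8
            type: DT_INT16
            type: DT_INT32
            type: DT_INT64
            type: DT_COMPLEX64
            type: DT_COMPLEX128
            type: DT_STRING
        }
    }
}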
  • The graph_def attribute
graph_def holds the nodes of the computation graph, one NodeDef per node:
message GraphDef {
    repeated NodeDef node = 1;
    VersionDef versions = 4;
}

message NodeDef {
    string name = 1;
    string op = 2;
    repeated string input = 3;
    string device = 4;
    map<string, AttrValue> attr = 5;
}
  • The saver_def attribute
message SaverDef {
    string filename_tensor_name = 1;
    string save_tensor_name = 2;
    string restore_op_name = 3;
    int32 max_to_keep = 4;
    bool sharded = 5;
    float keep_checkpoint_every_n_hours = 6;

    enum CheckpointFormatVersion {
        LEGACY = 0;
        V1 = 1;
        V2 = 2;
    }
    CheckpointFormatVersion version = 7;
}

A saver_def recorded in an exported meta graph looks like this:

saver_def {
    filename_tensor_name: "save/Const:0"
    save_tensor_name: "save/control_dependency:0"
    restore_op_name: "save/restore_all"
    max_to_keep: 5
    keep_checkpoint_every_n_hours: 10000.0
    version: V2
}
  • The collection_def attribute
message CollectionDef {
    message NodeList {
        repeated string value = 1;
    }
    message BytesList {
        repeated bytes value = 1;
    }
    message Int64List {
        repeated int64 value = 1 [packed = true];
    }
    message FloatList {
        repeated float value = 1 [packed = true];
    }
    message AnyList {
        repeated google.protobuf.Any value = 1;
    }
    oneof kind {
        NodeList node_list = 1;
        BytesList bytes_list = 2;
        Int64List int64_list = 3;
        FloatList float_list = 4;
        AnyList any_list = 5;
    }
}
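
The max_to_keep: 5 field in saver_def means Saver keeps only the five most recent checkpoints and deletes older ones as new ones are written. A value of 0 or None keeps all of them. keep_checkpoint_every_n_hours can additionally preserve some older checkpoints; this sketch ignores it.

Below is a simplified stdlib sketch of that bookkeeping. It assumes checkpoints are identified by their prefix strings, and record_checkpoint is hypothetical.

```python
def record_checkpoint(kept, new_path, max_to_keep=5):
    """Register a newly written checkpoint and prune the oldest ones.

    Returns (kept, deleted). A max_to_keep of 0 or None keeps everything.
    """
    # Re-saving an existing prefix moves it to the newest position
    kept = [p for p in kept if p != new_path] + [new_path]
    if not max_to_keep:
        return kept, []
    return kept[-max_to_keep:], kept[:-max_to_keep]


kept, deleted_all = [], []
for step in range(1, 8):
    kept, deleted = record_checkpoint(kept, f"model.ckpt-{step}")
    deleted_all.extend(deleted)
print(kept)         # ['model.ckpt-3', 'model.ckpt-4', 'model.ckpt-5', 'model.ckpt-6', 'model.ckpt-7']
print(deleted_all)  # ['model.ckpt-1', 'model.ckpt-2']
```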

2. Reading variable values from the .index and .data files

import tensorflow as tf
# The argument is the checkpoint prefix, not the name of a single file
reader=tf.train.NewCheckpointReader("/home/xxy/model/model.ckpt")
# Dictionary of variable name -> shape for every variable in the checkpoint
all_variables=reader.get_variable_to_shape_map()
print(all_variables)

for variable_name in all_variables:
    print(variable_name,"shape is:",all_variables[variable_name])
print("Value for variable a is : ",reader.get_tensor("a"))
print("Value for variable b is : ",reader.get_tensor("b"))

13.3 Saving Models in TensorFlow 2.0
save (HDF5 format)

import tensorflow as tf
from tensorflow.keras import layers

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

# Flatten each 28x28 image into 784 values and scale the pixels to [0, 1]
train_images=train_images.reshape(60000,784).astype('float32')/255
test_images=test_images.reshape(10000,784).astype('float32')/255

inputs=tf.keras.Input(shape=(784,),name='digits')

x=layers.Dense(500,activation='relu',name='dense_1')(inputs)

outputs=layers.Dense(10,activation='softmax',name='predictions')(x)

mlpmodel=tf.keras.Model(inputs=inputs,outputs=outputs,name='MLPModel')

mlpmodel.summary()

mlpmodel.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD(),metrics=['accuracy'])
mlpmodel.fit(x=train_images,y=train_labels,epochs=10,batch_size=100,validation_data=(test_images,test_labels))
# The .h5 extension saves architecture, weights and optimizer state in one HDF5 file
mlpmodel.save("/home/xxy/model/model.h5")
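
Fashion-MNIST images are 28x28 grayscale arrays with pixel values from 0 to 255. Each image therefore flattens to 28 * 28 = 784 inputs, matching Input(shape=(784,)), and dividing by 255 scales every pixel into [0, 1]. The NumPy reshape/astype line does this for the whole array at once, in row-major order.

Below is a stdlib sketch for a single image. The 28x28 input is made up for illustration.

```python
def flatten_and_scale(image):
    """Flatten a 2-D list of 0-255 pixel values row by row and scale to [0, 1]."""
    return [pixel / 255 for row in image for pixel in row]


# A made-up 28x28 image: all zeros except one bright pixel at row 14, column 14
image = [[0] * 28 for _ in range(28)]
image[14][14] = 255
vector = flatten_and_scale(image)
print(len(vector))               # 784
print(max(vector), min(vector))  # 1.0 0.0
print(vector.index(1.0))         # 406, i.e. 14 * 28 + 14
```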

read (HDF5 format)

import tensorflow as tf
import numpy as np

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

test_images=test_images.reshape(10000,784).astype('float32')/255

# Rebuild the complete model (architecture and weights) from the HDF5 file
load_mlpmodel=tf.keras.models.load_model("/home/xxy/model/model.h5")
class_name=['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
predictions=load_mlpmodel.predict(test_images)
# Compare predicted and true classes for the first 100 test images
for i in range(100):
    print("predict class result is : ",class_name[np.argmax(predictions[i])])
    print("correct class result is : ", class_name[test_labels[i]])
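
Each row returned by predict() holds ten softmax probabilities, one per class. np.argmax picks the index of the largest, which is then looked up in class_name.

Below is the same lookup in plain Python, using a made-up probability vector. max() with a key returns the first maximal index on ties, as np.argmax does.

```python
class_name = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
              'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def predicted_class(probabilities, names):
    """Return the name of the class with the highest probability (argmax)."""
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return names[best]


# Made-up softmax output for one test image
probs = [0.01, 0.02, 0.05, 0.02, 0.03, 0.01, 0.06, 0.04, 0.01, 0.75]
print(predicted_class(probs, class_name))  # Ankle boot
```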

save (SavedModel format)

import tensorflow as tf
from tensorflow.keras import layers

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

train_images=train_images.reshape(60000,784).astype('float32')/255
test_images=test_images.reshape(10000,784).astype('float32')/255

inputs=tf.keras.Input(shape=(784,),name='digits')

x=layers.Dense(500,activation='relu',name='dense_1')(inputs)

outputs=layers.Dense(10,activation='softmax',name='predictions')(x)

mlpmodel=tf.keras.Model(inputs=inputs,outputs=outputs,name='MLPModel')

mlpmodel.summary()

mlpmodel.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD(),metrics=['accuracy'])
mlpmodel.fit(x=train_images,y=train_labels,epochs=10,batch_size=100,validation_data=(test_images,test_labels))
# Export in the SavedModel format (TF 2.0 experimental API; later 2.x releases
# deprecate it in favor of mlpmodel.save(path, save_format='tf'))
tf.keras.experimental.export_saved_model(mlpmodel,"/home/xxy/model/")

read (SavedModel format)

import tensorflow as tf
import numpy as np

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

test_images=test_images.reshape(10000,784).astype('float32')/255

# Load the SavedModel (TF 2.0 experimental API; later 2.x releases use
# tf.keras.models.load_model instead)
load_mlpmodel=tf.keras.experimental.load_from_saved_model("/home/xxy/model/")
class_name=['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
predictions=load_mlpmodel.predict(test_images)
for i in range(100):
    print("predict class result is : ",class_name[np.argmax(predictions[i])])
    print("correct class result is : ", class_name[test_labels[i]])

# get_config() captures only the architecture; from_config() rebuilds the same
# model with newly initialized (untrained) weights
config=load_mlpmodel.get_config()
config_model=tf.keras.Model.from_config(config)

save_weight

import tensorflow as tf
from tensorflow.keras import layers

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

train_images=train_images.reshape(60000,784).astype('float32')/255
test_images=test_images.reshape(10000,784).astype('float32')/255

# A subclassed model: its architecture lives in Python code, so only the weights are saved here
class MLPModel(tf.keras.Model):
    def __init__(self,name=None):
        super(MLPModel,self).__init__(name=name)
        self.dense=layers.Dense(500,activation='relu',name='dense')
        self.dense_1=layers.Dense(10,activation='softmax',name='dense_1')

    def call(self,inputs):
        x=self.dense(inputs)
        return self.dense_1(x)

mlpmodel=MLPModel()
mlpmodel.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD())
history=mlpmodel.fit(train_images,train_labels,batch_size=100,epochs=10)

# Save the weights in TensorFlow checkpoint format, using the path as the file prefix
mlpmodel.save_weights("/home/xxy/model/",save_format='tf')

read_weight

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.fashion_mnist.load_data()

test_images=test_images.reshape(10000,784).astype('float32')/255

# The class must be defined exactly as it was when the weights were saved
class MLPModel(tf.keras.Model):
    def __init__(self,name=None):
        super(MLPModel,self).__init__(name=name)
        self.dense=layers.Dense(500,activation='relu',name='dense')
        self.dense_1=layers.Dense(10,activation='softmax',name='dense_1')

    def call(self,inputs):
        x=self.dense(inputs)
        return self.dense_1(x)

new_model=MLPModel()
new_model.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.SGD())
# The model is not built yet; the saved values are applied once its variables are created (first predict)
new_model.load_weights('/home/xxy/model/')

class_name=['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
new_predictions=new_model.predict(test_images)
for i in range(100):
    print("predict class result is : ",class_name[np.argmax(new_predictions[i])])
    print("correct class result is : ", class_name[test_labels[i]])

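13.4 PB Files
A .pb file stores the computation graph with its variables converted to constants, so the graph and the trained values are kept in a single file.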
writer

import tensorflow as tf
# graph_util is implemented in tensorflow/python/framework/graph_util.py
from tensorflow.python.framework import graph_util

a=tf.Variable(tf.constant([1.0],shape=[1]),name="a")
b=tf.Variable(tf.constant([3.0],shape=[1]),name="b")
result=a+b
init_op=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    # Export the GraphDef of the current graph
    graph_def=tf.get_default_graph().as_graph_def()

    # Keep only the nodes needed to compute "add" and turn the variables into constants
    output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add'])
    with tf.gfile.GFile("/home/xxy/model/model.pb","wb") as f:
        f.write(output_graph_def.SerializeToString())
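
convert_variables_to_constants(sess, graph_def, ['add']) does two things. It keeps only the nodes that the listed outputs depend on, and it replaces each variable among them with a Const node holding that variable's current value in sess. That is why the .pb file can be used without a separate checkpoint.

Below is a toy stdlib sketch of the same two steps on a dict-based graph. The dict layout, the node list and freeze_graph are hypothetical simplifications; TensorFlow itself works on GraphDef protos.

```python
def freeze_graph(nodes, values, output_names):
    """Return a frozen copy of the part of `nodes` that `output_names` depend on.

    nodes:  dict of node name -> {"op": str, "inputs": [node names]}
    values: current variable values, keyed by variable node name
    """
    # 1. Collect every node the outputs depend on (depth-first walk over inputs)
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(nodes[name]["inputs"])
    # 2. Copy those nodes, turning variables into constants that carry their values
    frozen = {}
    for name in needed:
        node = nodes[name]
        if node["op"] == "VariableV2":
            frozen[name] = {"op": "Const", "inputs": [], "value": values[name]}
        else:
            frozen[name] = {"op": node["op"], "inputs": list(node["inputs"])}
    return frozen


# Simplified version of the graph above: result = a + b, plus an unrelated init op
graph = {
    "a": {"op": "VariableV2", "inputs": []},
    "a/read": {"op": "Identity", "inputs": ["a"]},
    "b": {"op": "VariableV2", "inputs": []},
    "b/read": {"op": "Identity", "inputs": ["b"]},
    "add": {"op": "Add", "inputs": ["a/read", "b/read"]},
    "init": {"op": "NoOp", "inputs": []},
}
frozen = freeze_graph(graph, {"a": [1.0], "b": [3.0]}, ["add"])
print(sorted(frozen))  # ['a', 'a/read', 'add', 'b', 'b/read']
print(frozen["a"])     # {'op': 'Const', 'inputs': [], 'value': [1.0]}
```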

reader

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile("/home/xxy/model/model.pb","rb") as f:
        graph_def=tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the frozen graph and fetch the tensor add:0, i.e. the saved result
    result=tf.import_graph_def(graph_def,return_elements=["add:0"])

    print(sess.run(result))