详解常见的tensorflow,keras模型保存和加载

 一定要看代码注释,这些模型格式可以互相转化,有兴趣自己查阅学习!

tensorflow保存的文件格式多种多种:TFLite, frozen graph, SavedModel, serving model, TFHub representation, Keras's .h5  .其实tensorflow和keras是一家,可以理解为tensorflow是c,keras则是python ,就是封装成直接调用的简单方法成了keras了。

Tensorflow

1、CheckPoint(.ckpt) :.ckpt方式保存模型,这种模型文件是依赖 TensorFlow 的,只能在其框架下使用(我很少使用这种,初学感觉文件好多,也不好理解)

参考TensorFlow模型保存和提取方法

保存:

利用tf.train.Saver类实现模型的保存和加载,直接上代码!!!



import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "model.ckpt")

这种方式model.ckpt生成4个文件:如图

checkpoint文件:b包含最新的和所有的文件地址
.data文件:包含训练变量的文件(就是神经网络中权重w以及一些其他变量)
.index文件:描述variable中key和value的对应关系
.meta文件:保存完整的网络图结构(也就是graph,你定义好的整个模型运行的框架)
使用这种方法保存模型时会保存成上面这四个文件

加载:

重新加载模型时通常只会用到.meta文件恢复图结构然后用.data文件把各个变量的值再加进去。



with tf.Session() as sess:
    #恢复计算图结构
    saver = tf.train.import_meta_graph(',model.ckpt.meta')

     #恢复所有变量信息
    saver.restore(sess, "model.ckpt")

#现在sess中已经恢复了网络结构和变量信息了,之后直接用节点的名称来调用,相关代码自己查吧!

2..GraphDef(.pb):这种格式一个文件就保存图的结构和变量,方便使用。

代码不完整,参考tensorflow保存模型的几种方法

保存:

tf.train.write_graph(') # 生成.pb, 再通过freeze_graph把.pb与ckpt固化成新的pb文件


# 把变量转成常量之后写入PB文件中
def SaveFrozenPb(nodeNameList, pbFile):
    gd = tf.graph_util.convert_variables_to_constants(sess,
                                        tf.get_default_graph().as_graph_def(),nodeNameList)
    with tf.gfile.GFile(pbFile, 'wb') as f:
        f.write(gd.SerializeToString())
# 通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件
# freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb
# 如果不调用freeze_graph, 直接使用会报错‘google.protobuf.message.DecodeError: Error parsing message’ 

def SavePbForFreezeGraph(pbDir, pbName):
    tf.train.write_graph(sess.graph_def, pbDir, pbName)



加载:


def RestorePb(sess, name):
    # 二进制读取模型文件
    with tf.gfile.FastGFile(name, 'rb') as f:                     
       graph_def = tf.GraphDef()                     
       graph_def.ParseFromString(f.read())                     
       sess.graph.as_default()                     
       tf.import_graph_def(graph_def, name='') # 导入计算图

if __name__ == '__main__':
    sess = tf.Session()
     RestorePb(sess, './hello_frozen.pb')

Keras

将tensorflow封装起来好多方法形成了keras,它的母体还是tensorflow,所以keras使用起来更方便

1..h5

保存:

keras的模型一般保存为后缀名为h5的文件,但是h5文件用save()和save_weight()保存效果是不一样的。这里我使用其他博主的代码:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。它的size也要比m2小很多。

加载:

加载模型:

from tensorflow.python.keras.models import load_model

model.load_model(m2.h5)
model.load_weights(‘m3.h5’)

如果需要加载权重到不同的网络结构中,可以通过层名字来加载模型: 
model.load_weights('my_model_weights.h5', by_name=True)

2.pb

因为 .h5格式的模型只适合在本地使用,不适合部署。tensorflow2.x 中keras提供了保存和加载 .pb 格式模型的方法,很简单。

# 保存模型结构和参数到文件
tf.keras.models.save_model(net,"model_save_path") # 默认生成 .pb 格式模型,也可以通过save_format 设置 .h5 格式
print('模型已保存')
# 加载
net=tf.keras.models.load_model("model_save_path")


 

  • 1
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值