tensorflow1,tensoflow2+keras,pytorch模型格式与引入

最近整理了一下这三个框架的模型格式与引入语句,可能有不太精准的地方,在2020最后一天先做一个简单的说明吧。欢迎大家指出问题,共同进步。


一.tensorflow模型格式

1. tensorflow 0.11以前的版本保存得到的模型为如下3个文件

checkpoint 为一个文本文件,记录训练过程中中间节点上保存的模型名称,首行记录最近一次保存的模型名称
model.meta meta文件保存完整的tensorflow图结构,所有的变量操作
model.ckpt   实现.data与.meta合并的效果

2.tensorflow 0.11及以上版本保存如下4个文件

checkpoint 该文件记录保存最新检查点文件
model.index.index保存的是变量中key与value的对应关系
model.data .data保存的是训练变量参数的值
model.meta  .meta保存的是完整的网络图结构

3.tensorflow2保存为savedmodel格式(文件夹),当模型具有自定义图层时,savedmodel是首选保存方法

----path_to_save_model

   ----assets  该文件夹通常为空,有时可能包含tensorflow内部用于还原模型的一些信息

   ----variables  该文件夹包含训练后模型权重与偏差,它包含大量变量,因此将其拆分为多个文件

        ----variables.data-00000-of-00002  参数值

        ----variables.data-00001-of-00002

        ----variables.index  参数

    ----saved_model.pb  该文件包含保存的模型架构信息,训练配置和优化器信息

4.tensorflow2保存为.h5格式(单独文件)

(1)模型结构,权重,配置,优化器状态整体存为.h5

(2)模型结构存为单独的文件,模型权重存为.h5文件

二.tensorflow1模型保存与载入语句

1.保存tensorflow1模型

        tensorflow中,想保存所有参数的图形和值,将为其创建tf.train.saver()类的实例,tensorflow变量仅在会话内有效,所以需要通过在刚刚创建的saver对象上调用save方法将模型保存在会话中。

import tensorflow as tf
.....
.....
.....
saver=tf.train.Saver()#创建类的实例
sess=tf.Session()#创建会话
sess.run(tf.global_variables_initializer())#初始化所有变量
saver.save(sess,'mymodel')

2.导入预训练模型

(1)创建网络

a.通过编写python来创建网络,手动编写代码将每个图层创建为原始模型。

b.将网络保存在.meta中,使用tf.train.import()函数像下面这样来重新创建网络。

tf.train.import_meta_graph('mymodel.meta')

(2)加载参数:加载此图形上训练的参数值,通过在程序上调用restore来恢复网络参数,程序是tf.train.Saver()类的实例。

with tf.session() as sess:
    new_saver=tf.train.import_meta_graph('mymodel.meta')
    new_saver.restore(sess,tf.train.latest_checkpoint('./'))
#之后张量的值已经可以恢复并且访问
    print(sess.run('w1:0'))

三.keras模型保存与与载入语句

1.保存模型整体为.h5文件并引入

model=tf.keras.applications.ResNet50(weights="imagenet")

model.summary()#查看网络结构

model.save("mymodel.h5")#保存模型权重与结构到一个.h5文件中

new_model=keras.models.load_model("mymodel.h5")#加载模型结构与权重

2.将权重保存为.h5,结构保存为json格式,分别引入

json_config=model.to_json()#获取model的网络结构

with open('model_config.json','w') as json_file:

        json_file.write(json_config)#将网络结构文件写入“model_config.json"并保存

model.save_weights('mymodel.h5')#仅保存模型权重

with open('model_config.json') as json_file:

        json_config=json_file.read()

new_model=keras.models.model_from_json(json_config)#加载保存的结构文件

new_model.load_weights('path_to_my_weights.h5')#加载保存的权重

3.将模型与权重保存为Savedmodel文件夹格式

model.save('wenjianjia',save_format='tf')#保存模型为SavedModel格式

newmodel=keras.models.load_model('wenjianjia')#引入savedmodel

四.pytorch模型保存与载入语句

1.模型结构与参数保存为一个整体

torch.save(model,'XX.pth')#首先将模型保存为pth文件

model=torch.load('Xx.pth')#载入保存的模型

pthfile=r'C:/cnn/me.pth'

net=torch.load(pthfile,map_location=torch.device('cpu'))

2.分别加载网络结构与参数

torch.save(my_resnet.state_dict(),"my_resnet.pth")#将my_resnet模型权重存储为my_resnet.pth

resnet.load_state_dict(torch.load('my_resnet.pth')#其中resnet是my_resnet.pth对应的网络结构

 


比较概括的先写到这里啦,明年可能会在补充些示例。

 

 

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值