最近整理了一下这三个框架的模型格式与引入语句,可能有不太精准的地方,在2020最后一天先做一个简单的说明吧。欢迎大家指出问题,共同进步。
一.tensorflow模型格式
1. tensorflow 0.11以前的版本保存得到的模型为如下3个文件
checkpoint | 为一个文本文件,记录训练过程中中间节点上保存的模型名称,首行记录最近一次保存的模型名称 |
model.meta | meta文件保存完整的tensorflow图结构,所有的变量操作 |
model.ckpt | 实现.data与.meta合并的效果 |
2.tensorflow 0.11及以上版本保存如下4个文件
checkpoint | 该文件记录保存最新检查点文件 |
model.index | .index保存的是变量中key与value的对应关系 |
model.data | .data保存的是训练变量参数的值 |
model.meta | .meta保存的是完整的网络图结构 |
3.tensorflow2保存为savedmodel格式(文件夹),当模型具有自定义图层时,savedmodel是首选保存方法
----path_to_save_model
----assets 该文件夹通常为空,有时可能包含tensorflow内部用于还原模型的一些信息
----variables 该文件夹包含训练后模型权重与偏差,它包含大量变量,因此将其拆分为多个文件
----variables.data-00000-of-00002 参数值
----variables.data-00001-of-00002
----variables.index 参数
----saved_model.pb 该文件包含保存的模型架构信息,训练配置和优化器信息
4.tensorflow2保存为.h5格式(单独文件)
(1)模型结构,权重,配置,优化器状态整体存为.h5
(2)模型结构存为单独的文件,模型权重存为.h5文件
二.tensorflow1模型保存与载入语句
1.保存tensorflow1模型
tensorflow中,想保存所有参数的图形和值,将为其创建tf.train.saver()类的实例,tensorflow变量仅在会话内有效,所以需要通过在刚刚创建的saver对象上调用save方法将模型保存在会话中。
import tensorflow as tf
.....
.....
.....
saver=tf.train.Saver()#创建类的实例
sess=tf.Session()#创建会话
sess.run(tf.global_variables_initializer())#初始化所有变量
saver.save(sess,'mymodel')
2.导入预训练模型
(1)创建网络
a.通过编写python来创建网络,手动编写代码将每个图层创建为原始模型。
b.将网络保存在.meta中,使用tf.train.import()函数像下面这样来重新创建网络。
tf.train.import_meta_graph('mymodel.meta')
(2)加载参数:加载此图形上训练的参数值,通过在程序上调用restore来恢复网络参数,程序是tf.train.Saver()类的实例。
with tf.session() as sess:
new_saver=tf.train.import_meta_graph('mymodel.meta')
new_saver.restore(sess,tf.train.latest_checkpoint('./'))
#之后张量的值已经可以恢复并且访问
print(sess.run('w1:0'))
三.keras模型保存与与载入语句
1.保存模型整体为.h5文件并引入
model=tf.keras.applications.ResNet50(weights="imagenet")
model.summary()#查看网络结构
model.save("mymodel.h5")#保存模型权重与结构到一个.h5文件中
new_model=keras.models.load_model("mymodel.h5")#加载模型结构与权重
2.将权重保存为.h5,结构保存为json格式,分别引入
json_config=model.to_json()#获取model的网络结构
with open('model_config.json','w') as json_file:
json_file.write(json_config)#将网络结构文件写入“model_config.json"并保存
model.save_weights('mymodel.h5')#仅保存模型权重
with open('model_config.json') as json_file:
json_config=json_file.read()
new_model=keras.models.model_from_json(json_config)#加载保存的结构文件
new_model.load_weights('path_to_my_weights.h5')#加载保存的权重
3.将模型与权重保存为Savedmodel文件夹格式
model.save('wenjianjia',save_format='tf')#保存模型为SavedModel格式
newmodel=keras.models.load_model('wenjianjia')#引入savedmodel
四.pytorch模型保存与载入语句
1.模型结构与参数保存为一个整体
torch.save(model,'XX.pth')#首先将模型保存为pth文件
model=torch.load('Xx.pth')#载入保存的模型
pthfile=r'C:/cnn/me.pth'
net=torch.load(pthfile,map_location=torch.device('cpu'))
2.分别加载网络结构与参数
torch.save(my_resnet.state_dict(),"my_resnet.pth")#将my_resnet模型权重存储为my_resnet.pth
resnet.load_state_dict(torch.load('my_resnet.pth')#其中resnet是my_resnet.pth对应的网络结构
比较概括的先写到这里啦,明年可能会在补充些示例。