paddle静态图模型的保存和加载

##静态图模型的保存与加载

"""
在静态图中,模型结构部分可以转化为可持久化的程序
"""

import paddle
import paddle.static as static

#开启静态图模式
paddle.enable_static()

#创建输入数据和网络
x=paddle.static.data(name='x',shape=[None,224],dtype='float32')
z=paddle.static.nn.fc(x,10)

print(x)
print(z)
#设置执行器开始训练
place=paddle.CPUPlace()
exe=paddle.static.Executor(place)
print(exe)
exe.run(paddle.static.default_startup_program())
prog=paddle.static.default_main_program()

##保存模型的参数
paddle.save(prog.state_dict(),'temp/model.pdparms')
print('静态图模型参数',prog.state_dict(),len(prog.state_dict()))
#载入模型参数
state_dict=paddle.load('temp/model.pdparms')
prog.set_state_dict(state_dict)
print('导入的模型参数',state_dict)

##保存整个静态图模型(包含静态图结构和参数)
paddle.save(prog,'temp/model.pdmodel')
print('整个静态图模型',prog)
#导入整个模型结构
prog=paddle.load('temp/model.pdmodel')
print('导入的模型结构',prog)



import paddle
from paddle.static import InputSpec
import numpy as np

"""
静态图
"""
paddle.enable_static()
# 定义输入数据的规格
input_spec = [InputSpec(shape=[None, 3, 224, 224], dtype='float32', name='image')]
# 加载推理模型
# 注意:将'model_dir'替换为你的模型文件所在的目录,'model_filename'和'params_filename'分别为模型结构和参数的文件名
[inference_program, feed_target_names, fetch_targets] = paddle.static.load_inference_model(
    path_prefix='./inference',
    model_filename='inference.pdmodel',
    params_filename='inference.pdiparams',
    executor=paddle.static.Executor()
)
# 准备输入数据
input_data = np.random.randn(1,3,224,224).astype(np.float32) # 这里应该是你的输入数据,例如一个NumPy数组
# 创建执行器
place = paddle.CPUPlace()  # 或者使用paddle.CUDAPlace(0)如果你在GPU上运行
exe = paddle.static.Executor(place)

# 执行推理
results = exe.run(program=inference_program,
                  feed={feed_target_names[0]: input_data},
                  fetch_list=fetch_targets)

# 处理输出结果
# results是一个列表,包含了fetch_targets中指定的输出变量的值
# print(results)

import paddle

# 加载参数
params_dict = paddle.load('./inference/inference.pdiparams.info')
print(params_dict)
print('*'*50)
# 打印参数字典的键,这些键通常是参数的名称
for key in params_dict:
    print(key)

"""
动态图
"""

# 加载模型参数
model_dict = paddle.load('./hh/latest.pdparams')
for k,v in model_dict.items():
    print(k,v.shape)
# print(model_dict)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一壶浊酒..

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值