Paddle.fluid之模型freeze

model.freeze

#-*- coding:utf-8 -*-
import os
import paddle 
import paddle.fluid as fluid
from model_zoos import antisp as model


def gen_exe(resume = '', use_cuda = False):

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if resume:
        def if_exist(var):
            has_var = os.path.exists(os.path.join(resume, var.name))
            if has_var:
                print('var: %s found' % (var.name))
                return has_var
        fluid.io.load_vars(exe, resume, predicate=if_exist)
        print('init weight finished!')
    return exe


def freeze_model(exe, save_dir, tvars):

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    
    if save_dir is not None:
        fluid.io.save_inference_model(
            dirname = save_dir,
            feeded_var_names = ['image'], 
            target_vars = [tvars], 
            executor = exe,
            model_filename = 'model', 
            params_filename = 'params')
    print('paddle inference model saved')


if __name__ == "__main__":
    
    image = fluid.layers.data(name='image', shape=[1, 3, img_h, img_w], dtype='float32', append_batch_size=False)
    output, _, _ = model(image, True, 1.0)
    resume = 'checkpoints/param_weights/iter_xxx_x_eppch'
    save_dir = 'checkpoints/freeze_model'
    exe = gen_exe(resume)
    freeze_model(exe, save_dir, output)

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ReLuJie

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值