How to copy parameters between two programs in PaddlePaddle

  • Problem description: In the Fluid version of PaddlePaddle, how do I copy parameters between two programs? I tried to do so with the code below, but it has no effect.

  • Reproduction:

import paddle.fluid as fluid
from rllab import lab
from rllab.utils import logger
import numpy as np

x = lab.data(name='x', shape=[5], dtype='float32')

policy_program = fluid.default_main_program().clone()
with fluid.program_guard(policy_program):
    with lab.variable_scope('policy'):
        y1 = lab.FullyConnected('fc', x, 10)

    # Keep only the non-gradient variables under the 'policy' scope
    all_vars = fluid.default_main_program().list_vars()
    policy_vars = [v for v in all_vars if 'GRAD' not in v.name and 'policy' in v.name]
    for each in policy_vars:
        logger.info(each.name)

value_program = fluid.default_main_program().clone()
with fluid.program_guard(value_program):
    with lab.variable_scope('value'):
        y2 = lab.FullyConnected('fc', x, 10)
    # Keep only the non-gradient variables under the 'value' scope
    all_vars = fluid.default_main_program().list_vars()
    value_vars = [v for v in all_vars if 'GRAD' not in v.name and 'value' in v.name]


policy_vars.sort(key=lambda x:x.name)
value_vars.sort(key=lambda x:x.name)
sync_program = fluid.default_main_program().clone()
with fluid.program_guard(sync_program):
    sync_ops = []
    for i, var in enumerate(policy_vars):
        logger.info("[assign] policy:{}   value:{}".format(policy_vars[i].name, value_vars[i].name))
        sync_op = lab.assign(policy_vars[i], value_vars[i])
        sync_ops.append(sync_op)
    sync_program = sync_program.prune(sync_ops)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# Sum of each weight tensor before running the sync program
w0 = fluid.global_scope().find_var("policy/fc_W").get_tensor()
w1 = fluid.global_scope().find_var("value/fc_W").get_tensor()
print(np.sum(np.array(w0)))
print(np.sum(np.array(w1)))

exe.run(sync_program)
print('-------------')

# Sum of each weight tensor after the sync; the two sums are expected to match
w0 = fluid.global_scope().find_var("policy/fc_W").get_tensor()
w1 = fluid.global_scope().find_var("value/fc_W").get_tensor()
print(np.sum(np.array(w0)))
print(np.sum(np.array(w1)))
  • Output:
-2.7668757
1.8978335
I0530 11:36:10.023646 294248 executor.cc:114] Create Variable feed global, which pointer is 0x5abd9e50
I0530 11:36:10.023655 294248 scope.cc:56] Create variable value/fc_W
I0530 11:36:10.023659 294248 executor.cc:119] Create Variable value/fc_W locally, which pointer is 0x5abdcbf0
I0530 11:36:10.023664 294248 executor.cc:114] Create Variable fetch global, which pointer is 0x5abd9df0
I0530 11:36:10.023669 294248 scope.cc:56] Create variable value/fc_b
I0530 11:36:10.023672 294248 executor.cc:119] Create Variable value/fc_b locally, which pointer is 0x5abd9cd0
I0530 11:36:10.023681 294248 executor.cc:334] CPUPlace Op(assign), inputs:{X[policy/fc_W[5, 10]({})]}, outputs:{Out[value/fc_W[0]({})]}.
I0530 11:36:10.023689 294248 tensor_util.cu:24] TensorCopy 5, 10 from CPUPlace to CPUPlace
I0530 11:36:10.023697 294248 tensor_util.cu:40] TensorCopy Done
I0530 11:36:10.023705 294248 executor.cc:334] CPUPlace Op(assign), inputs:{X[policy/fc_b[10]({})]}, outputs:{Out[value/fc_b[0]({})]}.
I0530 11:36:10.023710 294248 tensor_util.cu:24] TensorCopy 10 from CPUPlace to CPUPlace
I0530 11:36:10.023715 294248 tensor_util.cu:40] TensorCopy Done
I0530 11:36:10.023720 294248 scope.cc:40] Destroy variable value/fc_b
I0530 11:36:10.023726 294248 scope.cc:40] Destroy variable value/fc_W
-------------
-2.7668757 
1.8978335
  • Solution:

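Looking at the code above, sync_program never uses create_parameter(). That method creates a parameter, that is, a learnable variable whose value can change as the model trains. It is a fairly low-level API, usually used when writing custom operators; its full path is paddle.fluid.layers.create_parameter(). The log output is consistent with this diagnosis: value/fc_W and value/fc_b are created locally when sync_program runs and destroyed when it finishes, so the assigned values apparently never reach the variables read back through fluid.global_scope().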

For more on this method, see the PaddlePaddle API documentation: http://www.paddlepaddle.org/documentation/api/zh/1.0.0/layers.html
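To illustrate why the copy in the log appears to vanish, here is a minimal toy model in plain Python. It does not use PaddlePaddle and is not Paddle's implementation: the names Scope, run_assign and dst_persistable are hypothetical, introduced only for this sketch. It assumes, as the "Create Variable ... locally" and "Destroy variable" log lines suggest, that a variable which is not treated as a parameter is created in a temporary per-run scope, while a parameter-like variable lives in the global scope.

```python
# Toy model only: Scope and run_assign are hypothetical, NOT PaddlePaddle APIs.


class Scope(object):
    """A name -> value table with an optional parent scope."""

    def __init__(self, parent=None):
        self.parent = parent
        self.vars = {}

    def find(self, name):
        """Return the nearest scope (self first, then ancestors) holding name."""
        scope = self
        while scope is not None:
            if name in scope.vars:
                return scope
            scope = scope.parent
        return None


def run_assign(global_scope, src, dst, dst_persistable):
    """Simulate one run of 'assign src -> dst' under the assumption above."""
    local_scope = Scope(parent=global_scope)  # temporary scope for this run
    if dst_persistable:
        # Parameter-like variable: it belongs to the global scope.
        global_scope.vars.setdefault(dst, None)
    else:
        # Ordinary variable: created locally, shadowing the global one.
        local_scope.vars[dst] = None
    source = local_scope.find(src)
    target = local_scope.find(dst)
    target.vars[dst] = list(source.vars[src])  # the actual copy
    # local_scope is dropped on return, together with anything stored in it.


g = Scope()
g.vars["policy/fc_W"] = [1.0, 2.0]
g.vars["value/fc_W"] = [9.0, 9.0]

# Case 1: destination handled as a temporary variable, as in the log.
run_assign(g, "policy/fc_W", "value/fc_W", dst_persistable=False)
after_temporary = list(g.vars["value/fc_W"])

# Case 2: destination handled as a parameter in the global scope.
run_assign(g, "policy/fc_W", "value/fc_W", dst_persistable=True)
after_persistent = list(g.vars["value/fc_W"])

print(after_temporary, after_persistent)
```

In the first case the global value/fc_W is left unchanged because the copy lands in the temporary scope. In the second case the copy is visible afterwards, which is the behavior the question was after.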

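  • Further notes:
    The code uses rllab, so here is a brief introduction. rllab is a framework for researching reinforcement learning algorithms. Its official site is https://github.com/openai/rllab. It officially supports Python 3.5+ and is based on Theano. Compared with rllab, OpenAI Gym supports a wider range of environments and provides an online scoreboard for sharing training results. rllab also provides its own pygame-based visual environments and is compatible with OpenAI Gym. In addition, it ships implementations of several reinforcement learning algorithms; these reference implementations and components make it quicker to get started developing such algorithms. For installation, follow the official guide: https://rllab.readthedocs.io/en/latest/user/installation.html.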