【Spinning Up】六、spinup的run_utils,实现批量调参,极简模式
文章目录
前言:
关于这个批量调参的功能,spinup的官方文档就简单几句话:
Spinning Up ships with a tool called ExperimentGrid for making hyperparameter ablations easier. This is based on (but simpler than) the rllab tool called VariantGenerator.
主要有两个函数ExperimentGrid和Calling Experiments。
一个是网格参数搜索,一个是启动实验。
关于这个ExperimentGrid,其实不用那么复杂,简单点用一个列表迭代就可以了。
无法替代的是这个Calling Experiments,文档里介绍的是:
This wraps a few pieces of functionality which are useful when you want to run many experiments in sequence, including logger configuration and splitting into multiple processes for MPI.
这里提到,多进程是必须得用到这个。
其次如果用的是TensorFlow1,那么在单一进程,连续启动两次实验,会造成一个变量名冲突报错,具体的在我的这篇博客里有介绍:【Spinning Up】一文弄懂序列化模块json、pickle和cloudpickle
官方文档给的注释:
The way the experiment is actually executed is slightly complicated: the function is serialized to a string, and then run_entrypoint.py is executed in a subprocess call with the serialized string as an argument. run_entrypoint.py unserializes the function call and executes it. We choose to do it this way—instead of just calling the function directly here—to avoid leaking state between successive experiments.
他们自己也说这个实现有点复杂~要先把函数序列化成二进制字符串,然后用subprocess启动run_entrypoint.py脚本,将这个字符串通过命令行的方式传给脚本的sys.argv,被argparse接受。然后再run_entrypoint.py中再反序列化成函数,最后执行他。
选择这种方式是避免,连续实验的状态泄露!
我勒个去,一周之前我看到这段话,简直如同看天书~
看完了系列化和反序列化,以及subprocess,再加上多次debug,才搞明白,spinup到底在做什么…
最后话不多说,上例程,大伙儿应该可以快速利用到自己的代码中;
如果有什么更好更简洁的方案,也希望大家能够告知~
tune_funcs极简例程:
极简例程不涉及mpi,只用来表示子进程打包函数;
新建三个函数:
- tune_exps.py: 主函数
- tune_func.py: 待执行函数
- run_entrypoint.py: 执行隔离后的函数
tune_func.py: 待执行函数
先看tune_func.py,极度简单,就是传一个参数,打印一下进程号:
import os
def func(param=0):
print("param:", param)
print("pid", os.getpid())
print("-"*20)
if __name__=='__main__':
for i in range(2):
func()
run_entrypoint.py 入口函数
这个函数用的spinup自带的,也比较简单,就是一个接收命令行传参,一个反序列化过程,一个执行恢复后的函数的过程
import zlib
import pickle
import base64
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('encoded_thunk')
args = parser.parse_args()
thunk = pickle.loads(zlib.decompress(base64.b64decode(args.encoded_thunk)))
thunk()
tune_exps.py 调用主函数:
这里东西比较多,相比spinup原本的run_utils.py,我已经删到不能再删的地步了。
整个流程:
- 导入待执行函数
- 将批量参数放在列表中,如果一个不够,就多层循环套娃一下;
- 将函数和参数传入callexp()中;
- 用lambda语句,将函数和参数打包在一起;
- 将打包后的序列化成二进制字符串encoded_thunk ;
- 封装成命令行字符串列表:cmd = [‘python’, ‘run_entrypoint.py’, encoded_thunk]
- 利用subprocess.call_check(cmd)
- 完了~
import base64
from copy import deepcopy
import cloudpickle
import numpy as np
import os
import os.path as osp
import psutil
import string
import subprocess
from subprocess import CalledProcessError
import sys
from textwrap import dedent
import time
import zlib
# 导入待执行的函数
from tune_func import func
DIV_LINE_WIDTH = 80
def call_experiment(thunk, param, **kwargs):
"""
:param thunk:待启动的函数
:param param:批量参数名
:param kwargs: 其他的一些没考虑到的参数~用处不大,没事儿最好别写这个,容易造成混乱~
正常的函数,传入参数之后,就会直接执行。
但是通过这个神奇的lambda,就可以即把参数传进去,又不执行。返回出一个函数
再次调用的时候,只需要将返回值,加上括号,即当一个无参数传入的函数执行就可以了。
"""
lambda_thunk = lambda: thunk(param=param)
print("lambda_thunk:", lambda_thunk)
"""
下面的操作就有点抽象了,涉及到一个’序列化‘和‘反序列化’的操作;
先通过cloudpickle.dumps对函数thunk进行序列化,编码成一个二进制文件;
这个二进制文件可以通过pickle.loads解码成原本的函数,参数什么的都不变。
该是什么功能就是什么功能;
但为什么花这么大代价,整这一通操作呢?
因为如果在当前脚本下,连续跑一个DDPG-tf1,TD3-tf1,就会报错,因为当前进程里
之前DDPG的变量还没有清空,因此有如下报错:
ValueError: Variable main/pi/dense/kernel already exists, disallowed.
这个就很尴尬了,我是没有好的方案去清除,当前进程之前的那些变量;
因此只好将spinup自带的这个打包函数+单一超参数->启动子进程->解码->在子进程执行单一超参数的函数
这套操作涉及到不少知识点:
1.lambda处理函数; lambda本身就是返回一个未执行的函数;
2.pickle对python对象的编码和解码;这个得看我的博客
3.subprocess的启动子进程;这个随便搜一个博客就行了
4.argparse和命令行的sys.argv的交互;这个我的博客刚写的
5.子进程和父进程的关系;这个没什么好说的,知道有这么回事儿就行了。
6.如果要用mpi_fork()的话,那么必定不能在当前脚本下,多次启动。
"""
pickled_thunk = cloudpickle.dumps(lambda_thunk)
encoded_thunk = base64.b64encode(zlib.compress(pickled_thunk)).decode('utf-8')
# 当前脚本和entry_point.py的路径要在一起,要不然下面的语句要改。
entrypoint = osp.join(osp.abspath(osp.dirname(__file__)), 'run_entrypoint.py')
# subprocess的输入就是一个字符串列表,正常在命令行,该怎么输入,这个就该怎么写。
cmd = [sys.executable if sys.executable else 'python', entrypoint, encoded_thunk]
print("tune_exps_pid:", os.getpid())
try:
subprocess.check_call(cmd, env=os.environ)
except CalledProcessError:
err_msg = '\n'*3 + '='*DIV_LINE_WIDTH + '\n' + dedent("""
Check the traceback above to see what actually went wrong.
""") + '='*DIV_LINE_WIDTH + '\n'*3
print(err_msg)
raise
if __name__ == '__main__':
params_list = [2, 3, 4, ]
for param in params_list:
call_experiment(thunk=func, param=param)
打印结果:
可以看到主进程ID一直不变,子进程的ID一直更新。参数传输正常,一切OK。
lambda_thunk: <function call_experiment.<locals>.<lambda> at 0x7f0ca65da290>
tune_exps_pid: 18451
param: 2
pid 18482
--------------------
lambda_thunk: <function call_experiment.<locals>.<lambda> at 0x7f0ca65da290>
tune_exps_pid: 18451
param: 3
pid 18483
--------------------
lambda_thunk: <function call_experiment.<locals>.<lambda> at 0x7f0ca65da290>
tune_exps_pid: 18451
param: 4
pid 18484
--------------------
联系方式:
ps: 欢迎做强化的同学加群一起学习:
深度强化学习-DRL:799378128
欢迎关注知乎帐号:未入门的炼丹学徒
CSDN帐号:https://blog.csdn.net/hehedadaq
极简spinup+HER+PER代码实现,两小时配置完毕:https://github.com/kaixindelele/DRLib