Analysis of the training process
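Based on SMARTS (Part 2).1: straight environment + agent (PPO) training code walkthrough, analyzing what happens internally once the environment is configured and training is launched.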
From the smarts/benchmark directory, run:
python run.py scenarios/intersections/straight -f agents/ppo/baseline-continuous-control.yaml
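What exactly are PPO + SMARTS doing internally?
- Parsing the external yaml file.
yaml.safe_load turns the YAML text into nested Python objects: mappings become dicts and sequences become lists, so any value can be reached by indexing level by level. Take the following yaml as an example: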
name: ppo-baseline
agent:
  state:
    wrapper:
      name: Simple
    features:
      distance_to_center: True
      speed: True
      steering: True
      heading_errors: [20, continuous]
      neighbor: 8
  action:
    type: 1 # 0 for continuous, 1 for discrete
interface:
  max_episode_steps: 1000
  neighborhood_vehicles:
    radius: 50
  waypoints:
    lookahead: 50 # larger than size of heading errors
policy:
  framework: rllib
  trainer:
    path: ray.rllib.agents.ppo
    name: PPOTrainer
run:
  checkpoint_freq: 40
  checkpoint_at_end: True
  max_failures: 1000
  resume: False
  export_formats: [model, checkpoint]
  stop:
    time_total_s: 14400
  config:
    log_level: WARN
    num_workers: 1
    num_gpus: 0
    horizon: 1000
    clip_param: 0.2
    num_sgd_iter: 10
    sgd_minibatch_size: 32
    lr: 1e-4
    lambda: 0
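It is parsed with PyYAML:
import yaml

# base_dir and config_file are set earlier in run.py; the `/` operator means base_dir is a pathlib.Path
with open(base_dir / config_file, "r") as f:
    raw_config = yaml.safe_load(f)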
raw_config is therefore a nested structure of dicts and lists, and a specific value is read with chained keys: raw_config['agent']['state']['wrapper']['name'] gives "Simple".
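The nesting can be seen without a YAML file. The dict below is written by hand to mirror part of what yaml.safe_load returns for the config above (an excerpt, not the full file), and get_path is a hypothetical convenience helper, not part of SMARTS. One detail worth knowing: PyYAML follows YAML 1.1, whose float pattern requires a decimal point, so `lr: 1e-4` is loaded as the string '1e-4' rather than a float. The excerpt reproduces that, and converting with float() is one way to guard against it.

```python
from functools import reduce

# Hand-written excerpt of what yaml.safe_load returns for the file above:
# YAML mappings become dicts, sequences become lists.
raw_config = {
    "name": "ppo-baseline",
    "agent": {
        "state": {
            "wrapper": {"name": "Simple"},
            "features": {
                "distance_to_center": True,
                "speed": True,
                "steering": True,
                "heading_errors": [20, "continuous"],
                "neighbor": 8,
            },
        },
        "action": {"type": 1},  # 0 for continuous, 1 for discrete
    },
    "run": {
        "config": {
            # PyYAML's YAML 1.1 float pattern needs a '.', so 1e-4 arrives as a str
            "lr": "1e-4",
            "lambda": 0,
        },
    },
}


def get_path(config, dotted_key):
    """Hypothetical helper: follow 'a.b.c' through nested dicts."""
    return reduce(lambda node, key: node[key], dotted_key.split("."), config)


# Chained indexing and the helper reach the same value.
assert raw_config["agent"]["state"]["wrapper"]["name"] == get_path(
    raw_config, "agent.state.wrapper.name"
)
print(get_path(raw_config, "agent.state.wrapper.name"))  # Simple

# Converting explicitly avoids passing a string learning rate downstream.
lr = float(raw_config["run"]["config"]["lr"])
print(lr)  # 0.0001
```

A dotted-path helper is only a readability choice; plain chained indexing, as in the text, behaves identically and raises KeyError on a missing level either way.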
raw_config is then passed to _make_rllib_config() in benchmark/agents/__init__.py, which parses it and returns the processed config.
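The yaml's policy.trainer block names the trainer by module path and class name (ray.rllib.agents.ppo and PPOTrainer). One plausible way to turn such a pair into a class is importlib.import_module plus getattr. This is a sketch of that idea under the assumption that the config is resolved this way, not the actual body of _make_rllib_config, and load_class is a hypothetical name. The example resolves standard-library names so it runs without Ray installed.

```python
import importlib


def load_class(path, name):
    """Hypothetical helper: import the module at `path` and return its attribute `name`."""
    module = importlib.import_module(path)
    return getattr(module, name)


# policy.trainer from the yaml; with a Ray version that provides ray.rllib.agents.ppo,
# load_class(**trainer_cfg) would return the PPOTrainer class.
trainer_cfg = {"path": "ray.rllib.agents.ppo", "name": "PPOTrainer"}

# A stdlib pair is used here so the sketch runs anywhere.
ordered = load_class("collections", "OrderedDict")
print(ordered.__name__)  # OrderedDict
```

Keeping the trainer as a string pair in the yaml lets the same run.py switch algorithms by editing the config instead of the code.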
Next, let's walk through how each part of the yaml is parsed.
- Action space
In the yaml:
action:
  type: 1 # 0 for continuous, 1 for discrete
This value is passed to the parser:
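from . import common  # relative import inside the benchmark/agents package

agent = config["agent"]                # config is the raw_config loaded above
action_type = agent["action"]["type"]  # 1 in this yaml
action_space = common.ActionSpace.from_type(action_type)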
from_type() is implemented as follows:
import gym
import numpy as np


class ActionSpace:
    @staticmethod
    def from_type(action_type: int):
        # ActionSpaceType is the SMARTS enum listed below
        space_type = ActionSpaceType(action_type)
        if space_type == ActionSpaceType.Continuous:
            # [throttle, brake, steering]
            return gym.spaces.Box(
                low=np.array([0.0, 0.0, -1.0]),
                high=np.array([1.0, 1.0, 1.0]),
                dtype=np.float32,
            )
        elif space_type == ActionSpaceType.Lane:
            # four discrete lane actions
            return gym.spaces.Discrete(4)
        else:
            raise NotImplementedError
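For the continuous case (type: 0) the Box has three components; in SMARTS's continuous action space these are throttle in [0, 1], brake in [0, 1] and steering in [-1, 1]. A policy's raw output is often clipped into such bounds before it reaches the environment. The dependency-free sketch below shows that step with plain lists; clip_to_box and in_box are hypothetical helpers, not gym or SMARTS API (gym's Box.contains also checks shape and dtype, which this ignores).

```python
LOW = [0.0, 0.0, -1.0]  # throttle, brake, steering lower bounds (same as the Box above)
HIGH = [1.0, 1.0, 1.0]  # upper bounds


def clip_to_box(action, low=LOW, high=HIGH):
    """Hypothetical helper: clamp each component into [low[i], high[i]]."""
    if len(action) != len(low):
        raise ValueError(f"expected {len(low)} components, got {len(action)}")
    return [min(max(a, lo), hi) for a, lo, hi in zip(action, low, high)]


def in_box(action, low=LOW, high=HIGH):
    """Rough analogue of Box.contains for a plain list (no dtype/shape checks)."""
    return len(action) == len(low) and all(
        lo <= a <= hi for a, lo, hi in zip(action, low, high)
    )


raw = [1.3, -0.2, -1.7]  # e.g. an unbounded network output
safe = clip_to_box(raw)
print(safe)                       # [1.0, 0.0, -1.0]
print(in_box(raw), in_box(safe))  # False True
```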
The value 1 (discrete) corresponds to ActionSpaceType.Lane:
from enum import Enum


class ActionSpaceType(Enum):
    Continuous = 0
    Lane = 1
    ActuatorDynamic = 2
    LaneWithContinuousSpeed = 3
    TargetPose = 4
    Trajectory = 5
    MultiTargetPose = 6  # for boid control
    MPC = 7
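ActionSpaceType(action_type) inside from_type is a lookup by value on the Enum. The sketch below uses a local copy of the enum (reproduced from the listing above with the standard-library enum module, not imported from SMARTS) and a hypothetical describe() that mirrors from_type's branching without building gym spaces. It shows three cases: the yaml value 1, a value that is a valid member but unhandled by from_type (2), and a value outside the enum (8), which fails with ValueError at the lookup, before the NotImplementedError branch is ever reached.

```python
from enum import Enum


class ActionSpaceType(Enum):
    # Local copy of the enum listed above, so the sketch runs without SMARTS.
    Continuous = 0
    Lane = 1
    ActuatorDynamic = 2
    LaneWithContinuousSpeed = 3
    TargetPose = 4
    Trajectory = 5
    MultiTargetPose = 6  # for boid control
    MPC = 7


def describe(action_type):
    """Hypothetical: mimic ActionSpace.from_type's branching, returning a label."""
    space_type = ActionSpaceType(action_type)  # ValueError if not a member value
    if space_type == ActionSpaceType.Continuous:
        return "Box(low=[0, 0, -1], high=[1, 1, 1])"
    elif space_type == ActionSpaceType.Lane:
        return "Discrete(4)"
    raise NotImplementedError(space_type.name)


print(ActionSpaceType(1))  # ActionSpaceType.Lane
print(describe(1))         # Discrete(4)
for bad in (2, 8):
    try:
        describe(bad)
    except (NotImplementedError, ValueError) as err:
        print(type(err).__name__)
# NotImplementedError
# ValueError
```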
So, for type: 1, common.ActionSpace.from_type() returns gym.spaces.Discrete(4).
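Discrete(4) means the policy emits a single integer in {0, 1, 2, 3}. SMARTS's lane controller works with high-level commands such as "keep_lane", "slow_down", "change_lane_left" and "change_lane_right"; the sketch below maps an index to one of these strings. The order in LANE_COMMANDS and the helper names are assumptions for illustration, not necessarily what the benchmark's action adapter does; check that code for the real mapping. gym's own definition of Discrete follows after this example.

```python
import random

# Assumed index -> command order; the benchmark's adapter may differ.
LANE_COMMANDS = ("keep_lane", "slow_down", "change_lane_left", "change_lane_right")
N = len(LANE_COMMANDS)  # matches Discrete(4)


def contains(x):
    """Same membership rule as Discrete(4).contains for Python ints."""
    return isinstance(x, int) and 0 <= x < N


def to_lane_command(index):
    """Hypothetical adapter: translate a discrete action index into a lane command."""
    if not contains(index):
        raise ValueError(f"action {index!r} is outside Discrete({N})")
    return LANE_COMMANDS[index]


rng = random.Random(0)
sampled = rng.randrange(N)  # analogous to Discrete(4).sample()
print(sampled, to_lane_command(sampled))
print(to_lane_command(2))  # change_lane_left
```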
class Discrete(Space):
    r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`.
    Example::
        >>> Discrete(2)
    """
    def __init__(self, n):
        assert n >= 0
        self.n = n
        super(Discrete, self).__init__((), np.int64)

    def sample(self):
        return self.np_random.randint(self.n)

    def contains(self, x):
        if isinstance(x, int):
            as_int = x
        elif isinstance(x, (np.generic, np.ndarray)) and (x.dtype.char in np.typecodes['AllInteger'] and x.shape == ()):
            as_int = int(x)
        else:
            return